From 0b72a8dc5c056fabde9fe247fb7e1916503d68ee Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 31 May 2024 11:04:00 -0500 Subject: [PATCH 01/14] WIP: No Orphans downstream --- docs/src/misc/mixin.md | 37 ++++++---- notebooks/01_Insert_Data.ipynb | 2 +- notebooks/03_Merge_Tables.ipynb | 10 +-- notebooks/py_scripts/01_Insert_Data.py | 2 +- notebooks/py_scripts/03_Merge_Tables.py | 12 ++-- src/spyglass/utils/dj_graph.py | 75 ++++++++++++++----- src/spyglass/utils/dj_merge_tables.py | 6 +- src/spyglass/utils/dj_mixin.py | 96 ++++++++++++------------- tests/utils/test_mixin.py | 4 +- 9 files changed, 145 insertions(+), 99 deletions(-) diff --git a/docs/src/misc/mixin.md b/docs/src/misc/mixin.md index 229747402..5cd5c16c2 100644 --- a/docs/src/misc/mixin.md +++ b/docs/src/misc/mixin.md @@ -131,29 +131,38 @@ masters, or null entry masters without matching data. For [Merge tables](./merge_tables.md), this is a significant problem. If a user wants to delete all entries associated with a given session, she must find all -Merge entries and delete them in the correct order. The mixin provides a -function, `delete_downstream_merge`, to handle this, which is run by default -when calling `delete`. +part table entries, including Merge tables, and delete them in the correct +order. The mixin provides a function, `delete_downstream_parts`, to handle this, +which is run by default when calling `delete`. -`delete_downstream_merge`, also aliased as `ddm`, identifies all Merge tables -downsteam of where it is called. If `dry_run=True`, it will return a list of -entries that would be deleted, otherwise it will delete them. +`delete_downstream_parts`, also aliased as `ddp`, identifies all part tables +with foreign key references downsteam of where it is called. If `dry_run=True`, +it will return a list of entries that would be deleted, otherwise it will delete +them. -Importantly, `delete_downstream_merge` cannot properly interact with tables that +Importantly, `delete_downstream_parts` cannot properly interact with tables that have not been imported into the current namespace. If you are having trouble with part deletion errors, import the offending table and rerun the function with `reload_cache=True`. ```python +import datajoint as dj from spyglass.common import Nwbfile restricted_nwbfile = Nwbfile() & "nwb_file_name LIKE 'Name%'" -restricted_nwbfile.delete_downstream_merge(dry_run=False) -# DataJointError("Attempt to delete part table MyMerge.Part before ... + +vanilla_dj_table = dj.FreeTable(dj.conn(), Nwbfile.full_table_name) +vanilla_dj_table.delete() +# DataJointError("Attempt to delete part table MyMerge.Part before ... ") + +restricted_nwbfile.delete() +# [WARNING] Spyglass: No part deletes found w/ Nwbfile ... +# OR +# ValueError("Please import MyMerge and try again.") from spyglass.example import MyMerge -restricted_nwbfile.delete_downstream_merge(reload_cache=True, dry_run=False) +restricted_nwbfile.delete_downstream_parts(reload_cache=True, dry_run=False) ``` Because each table keeps a cache of downsteam merge tables, it is important to @@ -164,13 +173,13 @@ Speed gains can also be achieved by avoiding re-instancing the table each time. 
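+# Note: the downstream-part cache is built per table instance (a cached
+# property), so instantiating Nwbfile() anew for every call rebuilds it;
+# reusing a single instance avoids that cost, as the two patterns below show.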
# Slow from spyglass.common import Nwbfile -(Nwbfile() & "nwb_file_name LIKE 'Name%'").ddm(dry_run=False) -(Nwbfile() & "nwb_file_name LIKE 'Other%'").ddm(dry_run=False) +(Nwbfile() & "nwb_file_name LIKE 'Name%'").ddp(dry_run=False) +(Nwbfile() & "nwb_file_name LIKE 'Other%'").ddp(dry_run=False) # Faster from spyglass.common import Nwbfile nwbfile = Nwbfile() -(nwbfile & "nwb_file_name LIKE 'Name%'").ddm(dry_run=False) -(nwbfile & "nwb_file_name LIKE 'Other%'").ddm(dry_run=False) +(nwbfile & "nwb_file_name LIKE 'Name%'").ddp(dry_run=False) +(nwbfile & "nwb_file_name LIKE 'Other%'").ddp(dry_run=False) ``` diff --git a/notebooks/01_Insert_Data.ipynb b/notebooks/01_Insert_Data.ipynb index 2a2297642..124da3d1c 100644 --- a/notebooks/01_Insert_Data.ipynb +++ b/notebooks/01_Insert_Data.ipynb @@ -2134,7 +2134,7 @@ "```python\n", "nwbfile = sgc.Nwbfile()\n", "\n", - "(nwbfile & {\"nwb_file_name\": nwb_copy_file_name}).delete_downstream_merge(\n", + "(nwbfile & {\"nwb_file_name\": nwb_copy_file_name}).delete_downstream_parts(\n", " dry_run=False, # True will show Merge Table entries that would be deleted\n", ")\n", "```\n", diff --git a/notebooks/03_Merge_Tables.ipynb b/notebooks/03_Merge_Tables.ipynb index 6adbbd5bf..0f9b2c3c2 100644 --- a/notebooks/03_Merge_Tables.ipynb +++ b/notebooks/03_Merge_Tables.ipynb @@ -90,7 +90,7 @@ "import spyglass.common as sgc\n", "import spyglass.lfp as lfp\n", "from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename\n", - "from spyglass.utils.dj_merge_tables import delete_downstream_merge, Merge\n", + "from spyglass.utils.dj_merge_tables import delete_downstream_parts, Merge\n", "from spyglass.common.common_ephys import LFP as CommonLFP # Upstream 1\n", "from spyglass.lfp.lfp_merge import LFPOutput # Merge Table\n", "from spyglass.lfp.v1.lfp import LFPV1 # Upstream 2" @@ -955,8 +955,8 @@ "2. use `merge_delete_parent` to delete from the parent sources, getting rid of\n", " the entries in the source table they came from.\n", "\n", - "3. use `delete_downstream_merge` to find Merge Tables downstream of any other\n", - " table and get rid full entries, avoiding orphaned master table entries.\n", + "3. use `delete_downstream_parts` to find downstream part tables, like Merge \n", + " Tables, and get rid full entries, avoiding orphaned master table entries.\n", "\n", "The two latter cases can be destructive, so we include an extra layer of\n", "protection with `dry_run`. When true (by default), these functions return\n", @@ -1016,7 +1016,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "`delete_downstream_merge` is available from any other table in the pipeline,\n", + "`delete_downstream_parts` is available from any other table in the pipeline,\n", "but it does take some time to find the links downstream. 
If you're using this,\n", "you can save time by reassigning your table to a variable, which will preserve\n", "a copy of the previous search.\n", @@ -1056,7 +1056,7 @@ "source": [ "nwbfile = sgc.Nwbfile()\n", "\n", - "(nwbfile & nwb_file_dict).delete_downstream_merge(\n", + "(nwbfile & nwb_file_dict).delete_downstream_parts(\n", " dry_run=True,\n", " reload_cache=False, # if still encountering errors, try setting this to True\n", ")" diff --git a/notebooks/py_scripts/01_Insert_Data.py b/notebooks/py_scripts/01_Insert_Data.py index 870c6907a..a0a84b828 100644 --- a/notebooks/py_scripts/01_Insert_Data.py +++ b/notebooks/py_scripts/01_Insert_Data.py @@ -378,7 +378,7 @@ # ```python # nwbfile = sgc.Nwbfile() # -# (nwbfile & {"nwb_file_name": nwb_copy_file_name}).delete_downstream_merge( +# (nwbfile & {"nwb_file_name": nwb_copy_file_name}).delete_downstream_parts( # dry_run=False, # True will show Merge Table entries that would be deleted # ) # ``` diff --git a/notebooks/py_scripts/03_Merge_Tables.py b/notebooks/py_scripts/03_Merge_Tables.py index ac3ad4e69..690bc7834 100644 --- a/notebooks/py_scripts/03_Merge_Tables.py +++ b/notebooks/py_scripts/03_Merge_Tables.py @@ -5,7 +5,7 @@ # extension: .py # format_name: light # format_version: '1.5' -# jupytext_version: 1.15.2 +# jupytext_version: 1.16.0 # kernelspec: # display_name: spy # language: python @@ -64,7 +64,7 @@ import spyglass.common as sgc import spyglass.lfp as lfp from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename -from spyglass.utils.dj_merge_tables import delete_downstream_merge, Merge +from spyglass.utils.dj_merge_tables import delete_downstream_parts, Merge from spyglass.common.common_ephys import LFP as CommonLFP # Upstream 1 from spyglass.lfp.lfp_merge import LFPOutput # Merge Table from spyglass.lfp.v1.lfp import LFPV1 # Upstream 2 @@ -192,8 +192,8 @@ # 2. use `merge_delete_parent` to delete from the parent sources, getting rid of # the entries in the source table they came from. # -# 3. use `delete_downstream_merge` to find Merge Tables downstream of any other -# table and get rid full entries, avoiding orphaned master table entries. +# 3. use `delete_downstream_parts` to find downstream part tables, like Merge +# Tables, and get rid full entries, avoiding orphaned master table entries. # # The two latter cases can be destructive, so we include an extra layer of # protection with `dry_run`. When true (by default), these functions return @@ -204,7 +204,7 @@ LFPOutput.merge_delete_parent(restriction=nwb_file_dict, dry_run=True) -# `delete_downstream_merge` is available from any other table in the pipeline, +# `delete_downstream_parts` is available from any other table in the pipeline, # but it does take some time to find the links downstream. If you're using this, # you can save time by reassigning your table to a variable, which will preserve # a copy of the previous search. 
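# A sketch of the time-saving pattern described above (the second restriction
# is illustrative): the first call on an instance builds the downstream-part
# cache, and later calls on the same instance reuse it.
#
#     nwbfile = sgc.Nwbfile()
#     (nwbfile & nwb_file_dict).ddp(dry_run=True)                   # builds cache
#     (nwbfile & "nwb_file_name LIKE 'Other%'").ddp(dry_run=True)   # reuses cache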
@@ -216,7 +216,7 @@ # + nwbfile = sgc.Nwbfile() -(nwbfile & nwb_file_dict).delete_downstream_merge( +(nwbfile & nwb_file_dict).delete_downstream_parts( dry_run=True, reload_cache=False, # if still encountering errors, try setting this to True ) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 5bf3d25d0..4526fa703 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -91,11 +91,10 @@ def __init__(self, seed_table: Table, verbose: bool = False, **kwargs): self.seed_table = seed_table self.connection = seed_table.connection - # Undirected graph may not be needed, but adding FT to the graph - # prevents `to_undirected` from working. If using undirected, remove - # PERIPHERAL_TABLES from the graph. self.graph = seed_table.connection.dependencies self.graph.load() + # undirect not needed in all cases but need to do before adding ft nodes + self.undirect_graph = self.graph.to_undirected() self.verbose = verbose self.leaves = set() @@ -243,6 +242,11 @@ def _and_parts(self, table): ret.extend(parts) return ret + def _ignore_peripheral(self): + """Ignore peripheral tables in graph traversal.""" + self.no_visit.update(PERIPHERAL_TABLES) + self.undirect_graph.remove_nodes_from(PERIPHERAL_TABLES) + # ---------------------------- Graph Traversal ----------------------------- def _bridge_restr( @@ -438,6 +442,47 @@ def all_ft(self): ] return [ft for ft in all_ft if len(ft) > 0] + def ft_from_list( + self, + tables: List[str], + with_restr: bool = True, + sort_from: str = None, + ) -> List[FreeTable]: + """Return non-empty FreeTable objects from list of table names. + + Parameters + ---------- + tables : List[str] + List of table names + with_restr : bool, optional + Restrict FreeTable to restriction. Default True. + sort_from : str, optional + Table name. Sort by distance from this table. Default None, no sort. + """ + + def graph_distance(self, table1: str = None, table2: str = None) -> int: + """Sort tables by distance from root. If no root, do nothing.""" + if not table1 or not table2: + return 0 + try: + return len(shortest_path(self.undirect_graph, table1, table2)) + except (NodeNotFound, NetworkXNoPath): + return 99 + + self.cascade() + tables = [self._ensure_name(t) for t in tables] + + if sort_from: + sort_from = self._ensure_name(sort_from) + tables = sorted( + tables, + key=lambda t: graph_distance(sort_from, t), + ) + + fts = [self._get_ft(table, with_restr=with_restr) for table in tables] + + return [ft for ft in fts if len(ft) > 0] + @property def as_dict(self) -> List[Dict[str, str]]: """Return as a list of dictionaries of table_name: restriction""" @@ -459,6 +504,7 @@ def __init__( direction: Direction = "up", cascade: bool = False, verbose: bool = False, + ignore_peripheral: bool = False, **kwargs, ): """Use graph to cascade restrictions up from leaves to all ancestors. @@ -487,6 +533,9 @@ def __init__( Default False verbose : bool, optional Whether to print verbose output. Default False + ignore_peripheral : bool, optional + Whether to ignore peripheral tables in graph traversal. 
Default + False """ super().__init__(seed_table, verbose=verbose) @@ -495,6 +544,8 @@ def __init__( ) self.add_leaves(leaves) + if ignore_peripheral: + self._ignore_peripheral() if cascade: self.cascade(direction=direction) @@ -832,7 +883,7 @@ def __init__( seed_table = parent if isinstance(parent, Table) else child super().__init__(seed_table=seed_table, verbose=verbose) - self.no_visit.update(PERIPHERAL_TABLES) + self._ignore_peripheral() self.no_visit.update(self._ensure_name(banned_tables) or []) self.no_visit.difference_update([self.parent, self.child]) self.searched_tables = set() @@ -900,20 +951,6 @@ def has_link(self) -> bool: _ = self.path return self.link_type is not None - @cached_property - def all_ft(self) -> List[dj.FreeTable]: - """Return list of FreeTable objects for each table in chain. - - Unused. Preserved for future debugging. - """ - if not self.has_link: - return None - return [ - self._get_ft(table, with_restr=False) - for table in self.path - if not table.isnumeric() - ] - @property def path_str(self) -> str: if not self.path: @@ -1058,8 +1095,8 @@ def find_path(self, directed=True) -> List[str]: search_graph = self.graph if not directed: + # FTs in self.graph prevent `to_undirected` from working. self.connection.dependencies.load() - self.undirect_graph = self.connection.dependencies.to_undirected() search_graph = self.undirect_graph search_graph.remove_nodes_from(self.no_visit) diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 37a51b674..5b667dbe6 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -820,11 +820,11 @@ def delete_downstream_merge( ) -> list: """Given a table/restriction, id or delete relevant downstream merge entries - Passthrough to SpyglassMixin.delete_downstream_merge + Passthrough to SpyglassMixin.delete_downstream_parts """ logger.warning( "DEPRECATED: This function will be removed in `0.6`. " - + "Use AnyTable().delete_downstream_merge() instead." + + "Use AnyTable().delete_downstream_parts() instead." ) from spyglass.utils.dj_mixin import SpyglassMixin @@ -833,4 +833,4 @@ def delete_downstream_merge( raise ValueError("Input must be a Spyglass Table.") table = table if isinstance(table, dj.Table) else table() - return table.delete_downstream_merge(**kwargs) + return table.delete_downstream_parts(**kwargs) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 35e54ea7a..19227b567 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -51,8 +51,8 @@ class SpyglassMixin: `restriction` can be set to a string to restrict the delete. `dry_run` can be set to False to commit the delete. `reload_cache` can be set to True to reload the merge cache. - ddm(*args, **kwargs) - Alias for delete_downstream_merge. + ddp(*args, **kwargs) + Alias for delete_downstream_parts cautious_delete(force_permission=False, *args, **kwargs) Check user permissions before deleting table rows. 
Permission is granted to users listed as admin in LabMember table or to users on a team with @@ -239,10 +239,10 @@ def fetch_pynapple(self, *attrs, **kwargs): for file_name in nwb_files ] - # ------------------------ delete_downstream_merge ------------------------ + # ------------------------ delete_downstream_parts ------------------------ def _import_merge_tables(self): - """Import all merge tables downstream of self.""" + """Import all merge tables.""" from spyglass.decoding.decoding_merge import DecodingOutput # noqa F401 from spyglass.lfp.lfp_merge import LFPOutput # noqa F401 from spyglass.linearization.merge import ( @@ -262,27 +262,26 @@ def _import_merge_tables(self): ) @cached_property - def _merge_tables(self) -> Dict[str, dj.FreeTable]: - """Dict of merge tables downstream of self: {full_table_name: FreeTable}. + def _part_masters(self) -> Dict[str, dj.FreeTable]: + """Dict of part tables downstream of self: {camel_name: FreeTable}. - Cache of items in parents of self.descendants(as_objects=True). Both - descendant and parent must have the reserved primary key 'merge_id'. + Cache of masters of self.descendants(as_objects=True). + Part must have other parent(s) besides master. """ self.connection.dependencies.load() - merge_tables = {} + part_masters = {} visited = set() def search_descendants(parent): for desc in parent.descendants(as_objects=True): - if ( - MERGE_PK not in desc.heading.names - or not (master_name := get_master(desc.full_table_name)) - or master_name in merge_tables + if ( # Check if has master, no other fk, or already in cache + not (master_name := get_master(desc.full_table_name)) + or not set(desc.parents()) - set([master_name]) + or master_name in part_masters ): continue master_ft = dj.FreeTable(self.connection, master_name) - if is_merge_table(master_ft): - merge_tables[master_name] = master_ft + part_masters[to_camel_case(master_name)] = master_ft if master_name not in visited: visited.add(master_name) search_descendants(master_ft) @@ -298,36 +297,34 @@ def search_descendants(parent): raise ValueError(f"Please import {table_name} and try again.") logger.info( - f"Building merge cache for {self.camel_name}.\n\t" - + f"Found {len(merge_tables)} downstream merge tables" + f"Building part-parent cache for {self.camel_name}.\n\t" + + f"Found {len(part_masters)} downstream merge tables" ) - return merge_tables + return part_masters @cached_property def _merge_chains(self) -> OrderedDict[str, List[dj.FreeTable]]: - """Dict of chains to merges downstream of self + """Dict of chains to parts downstream of self - Format: {full_table_name: TableChains}. + Format: {camel_name: TableChains}. - For each merge table found in _merge_tables, find the path from self to - merge via merge parts. If the path is valid, add it to the dict. Cache - prevents need to recompute whenever delete_downstream_merge is called + For each table found in _part_masters, find the path from self to + master via merge parts. If the path is valid, add it to the dict. Cache + prevents need to recompute whenever delete_downstream_part is called with a new restriction. To recompute, add `reload_cache=True` to - delete_downstream_merge call. + delete_downstream_part call. 
""" from spyglass.utils.dj_graph import TableChains # noqa F401 merge_chains = {} - for name, merge_table in self._merge_tables.items(): + for name, merge_table in self._part_masters.items(): chains = TableChains(self, merge_table) if len(chains): merge_chains[name] = chains - # This is ordered by max_len of chain from self to merge, which assumes - # that the merge table with the longest chain is the most downstream. - # A more sophisticated approach would order by length from self to - # each merge part independently, but this is a good first approximation. + # self->master chains ordered by max_len as a proxy for downstream-ness. + # A better approach would sort each self->part independently. return OrderedDict( sorted( @@ -342,7 +339,7 @@ def _get_chain(self, substring): return chain raise ValueError(f"No chain found with '{substring}' in name.") - def _commit_merge_deletes( + def _commit_part_deletes( self, merge_join_dict: Dict[str, List[QueryExpression]], **kwargs ) -> None: """Commit merge deletes. @@ -353,13 +350,13 @@ def _commit_merge_deletes( Dictionary of merge tables and their joins. Uses 'merge_id' primary key to restrict delete. - Extracted for use in cautious_delete and delete_downstream_merge.""" + Extracted for use in cautious_delete and delete_downstream_parts.""" for table_name, part_restr in merge_join_dict.items(): - table = self._merge_tables[table_name] - keys = [part.fetch(MERGE_PK, as_dict=True) for part in part_restr] + table = self._part_masters[table_name] + keys = [part.proj().fetch(as_dict=True) for part in part_restr] (table & keys).delete(**kwargs) - def delete_downstream_merge( + def delete_downstream_parts( self, restriction: str = None, dry_run: bool = True, @@ -403,7 +400,7 @@ def delete_downstream_merge( if not merge_join_dict and not disable_warning: logger.warning( - f"No merge deletes found w/ {self.camel_name} & " + f"No part deletes found w/ {self.camel_name} & " + f"{restriction}.\n\tIf this is unexpected, try importing " + " Merge table(s) and running with `reload_cache`." ) @@ -411,9 +408,9 @@ def delete_downstream_merge( if dry_run: return merge_join_dict.values() if return_parts else merge_join_dict - self._commit_merge_deletes(merge_join_dict, **kwargs) + self._commit_part_deletes(merge_join_dict, **kwargs) - def ddm( + def ddp( self, restriction: str = None, dry_run: bool = True, @@ -423,8 +420,8 @@ def ddm( *args, **kwargs, ) -> Union[List[QueryExpression], Dict[str, List[QueryExpression]]]: - """Alias for delete_downstream_merge.""" - return self.delete_downstream_merge( + """Alias for delete_downstream_parts.""" + return self.delete_downstream_parts( restriction=restriction, dry_run=dry_run, reload_cache=reload_cache, @@ -609,7 +606,7 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): if not force_permission: self._check_delete_permission() - merge_deletes = self.delete_downstream_merge( + merge_deletes = self.delete_downstream_parts( dry_run=True, disable_warning=True, return_parts=False, @@ -630,7 +627,7 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): or not safemode or user_choice("Commit deletes?", default="no") == "yes" ): - self._commit_merge_deletes(merge_deletes, **kwargs) + self._commit_part_deletes(merge_deletes, **kwargs) else: logger.info("Delete aborted.") self._log_delete(start) @@ -856,8 +853,8 @@ def restrict_by( Returns ------- - Union[QueryExpression, FindKeyGraph] - Restricted version of present table or FindKeyGraph object. 
If + Union[QueryExpression, TableChain] + Restricted version of present table or TableChain object. If return_graph, use all_ft attribute to see all tables in cascade. """ from spyglass.utils.dj_graph import TableChain # noqa: F401 @@ -897,11 +894,14 @@ def restrict_by( return graph ret = self & graph._get_restr(self.full_table_name) - if len(ret) == len(self) or len(ret) == 0: - logger.warning( - f"Failed to restrict with path: {graph.path_str}\n\t" - + "See `help(YourTable.restrict_by)`" - ) + warn_text = ( + f" after restrict with path: {graph.path_str}\n\t " + + "See `help(YourTable.restrict_by)`" + ) + if len(ret) == len(self): + logger.warning("Same length" + warn_text) + elif len(ret) == 0: + logger.warning("No entries" + warn_text) return ret diff --git a/tests/utils/test_mixin.py b/tests/utils/test_mixin.py index 010abf03c..c70c67b13 100644 --- a/tests/utils/test_mixin.py +++ b/tests/utils/test_mixin.py @@ -64,10 +64,10 @@ def test_get_chain(Nwbfile, pos_merge_tables): @pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") def test_ddm_warning(Nwbfile, caplog): """Test that the mixin warns on empty delete_downstream_merge.""" - (Nwbfile.file_like("BadName")).delete_downstream_merge( + (Nwbfile.file_like("BadName")).delete_downstream_parts( reload_cache=True, disable_warnings=False ) - assert "No merge deletes found" in caplog.text, "No warning issued." + assert "No part deletes found" in caplog.text, "No warning issued." def test_ddm_dry_run(Nwbfile, common, sgp, pos_merge_tables, lin_v1): From d030e7adce9a95871a330205695e1851dcfc5266 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 4 Jun 2024 16:02:15 -0500 Subject: [PATCH 02/14] WIP: Bidirectional RestrGraph, remove TableChains --- src/spyglass/utils/dj_graph.py | 255 ++++++++++++++------------------- src/spyglass/utils/dj_mixin.py | 172 ++++++++-------------- 2 files changed, 167 insertions(+), 260 deletions(-) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 4526fa703..e9d64e919 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -6,11 +6,10 @@ from abc import ABC, abstractmethod from collections.abc import KeysView from enum import Enum -from functools import cached_property +from functools import cached_property, partial from itertools import chain as iter_chain -from typing import Any, Dict, List, Set, Tuple, Union +from typing import Any, Dict, Iterable, List, Set, Tuple, Union -import datajoint as dj from datajoint import FreeTable, Table from datajoint.condition import make_condition from datajoint.dependencies import unite_master_parts @@ -70,6 +69,7 @@ class AbstractGraph(ABC): ------- cascade: Abstract method implemented by child classes cascade1: Cascade a restriction up/down the graph, recursively + ft_from_list: Return non-empty FreeTable objects from list of table names Properties ---------- @@ -110,6 +110,24 @@ def cascade(self): """Cascade restrictions through graph.""" raise NotImplementedError("Child class mut implement `cascade` method") + # --------------------------- Dunder Properties --------------------------- + + def __repr__(self): + l_str = ( + ",\n\t".join(self.leaves) + "\n" + if self.leaves + else self._camel(self.seed_table) + ) + casc_str = "Cascaded" if self.cascaded else "Uncascaded" + return f"{casc_str} {self.__class__.__name__}(\n\t{l_str})" + + def __getitem__(self, index: Union[int, str]): + all_ft_names = [t.full_table_name for t in self.all_ft] + return fuzzy_get(index, all_ft_names, self.all_ft) + + 
def __len__(self): + return len(self.all_ft) + # ---------------------------- Logging Helpers ---------------------------- def _log_truncate(self, log_str: str, max_len: int = 80): @@ -137,19 +155,19 @@ def _print_restr(self): # ------------------------------ Graph Nodes ------------------------------ - def _ensure_name(self, table: Union[str, Table] = None) -> str: + def _ensure_names(self, table: Union[str, Table] = None) -> str: """Ensure table is a string.""" if table is None: return None if isinstance(table, str): return table - if isinstance(table, list): - return [self._ensure_name(t) for t in table] + if isinstance(table, Iterable): + return [self._ensure_names(t) for t in table] return getattr(table, "full_table_name", None) def _get_node(self, table: Union[str, Table]): """Get node from graph.""" - table = self._ensure_name(table) + table = self._ensure_names(table) if not (node := self.graph.nodes.get(table)): raise ValueError( f"Table {table} not found in graph." @@ -174,8 +192,8 @@ def _get_edge(self, child: str, parent: str) -> Tuple[bool, Dict[str, str]]: Tuple of boolean indicating direction and edge data. True if child is child of parent. """ - child = self._ensure_name(child) - parent = self._ensure_name(parent) + child = self._ensure_names(child) + parent = self._ensure_names(parent) if edge := self.graph.get_edge_data(parent, child): return False, edge @@ -195,7 +213,7 @@ def _get_edge(self, child: str, parent: str) -> Tuple[bool, Dict[str, str]]: def _get_restr(self, table): """Get restriction from graph node.""" - return self._get_node(self._ensure_name(table)).get("restr") + return self._get_node(self._ensure_names(table)).get("restr") def _set_restr(self, table, restriction, replace=False): """Add restriction to graph node. If one exists, merge with new.""" @@ -220,7 +238,7 @@ def _set_restr(self, table, restriction, replace=False): def _get_ft(self, table, with_restr=False): """Get FreeTable from graph node. 
If one doesn't exist, create it.""" - table = self._ensure_name(table) + table = self._ensure_names(table) if with_restr: if not (restr := self._get_restr(table) or False): self._log_truncate(f"No restriction for {table}") @@ -233,20 +251,40 @@ def _get_ft(self, table, with_restr=False): return ft & restr - def _and_parts(self, table): - """Return table, its master and parts.""" - ret = [table] - if master := get_master(table): - ret.append(master) - if parts := self._get_ft(table).parts(): - ret.extend(parts) - return ret + # ------------------------------ Ignore Nodes ------------------------------ def _ignore_peripheral(self): """Ignore peripheral tables in graph traversal.""" self.no_visit.update(PERIPHERAL_TABLES) self.undirect_graph.remove_nodes_from(PERIPHERAL_TABLES) + def _ignore_from_dest(self, sources: List[str], dest: List[str]): + """Ignore nodes from destination(s) in graph traversal.""" + path_nodes = set() + + for source in self._ensure_names(sources): + for table in self._ensure_names(dest): + partial_path = partial( + shortest_path, source=source, target=table + ) + try: + path = partial_path(self.graph) + dir = "Directed" + except (NodeNotFound, NetworkXNoPath): + try: # Try undirected graph + path = partial_path(self.undirect_graph) + dir = "Undirect" + except (NodeNotFound, NetworkXNoPath): + path = ["NONE"] + dir = "NoPath " + self._log_truncate( + f"{dir}: {self._camel(source)} -> {self._camel(table)}: {path}" + ) + path_nodes.update(path) + + unused_nodes = set(self.graph.nodes) - path_nodes + self.no_visit.update(unused_nodes) + # ---------------------------- Graph Traversal ----------------------------- def _bridge_restr( @@ -337,11 +375,16 @@ def _get_next_tables(self, table: str, direction: Direction) -> Tuple: G = self.graph dir_dict = {"direction": direction} - bonus = {} + bonus = {} # Add master and parts to next tables direction = Direction(direction) if direction == Direction.UP: next_func = G.parents - bonus.update({part: {} for part in self._get_ft(table).parts()}) + table_ft = self._get_ft(table) + for part in table_ft.parts(): # Assumes parts do not alias master + bonus[part] = { + "attr_map": {k: k for k in table_ft.primary_key}, + **dir_dict, + } elif direction == Direction.DOWN: next_func = G.children if (master_name := get_master(table)) != "": @@ -386,9 +429,12 @@ def cascade1( next_tables, next_func = self._get_next_tables(table, direction) - self._log_truncate( - f"Checking {count:>2}: {self._camel(next_tables.keys())}" - ) + if next_list := next_tables.keys(): + self._log_truncate( + f"Checking {count:>2}: {self._camel(table)}" + + f" -> {self._camel(next_list)}" + ) + for next_table, data in next_tables.items(): if next_table.isnumeric(): # Skip alias nodes next_table, data = next_func(next_table).popitem() @@ -470,10 +516,10 @@ def graph_distance(self, table1: str = None, table2: str = None) -> int: return 99 self.cascade() - tables = [self._ensure_name(t) for t in tables] + tables = [self._ensure_names(t) for t in tables] if sort_from: - sort_from = self._ensure_name(sort_from) + sort_from = self._ensure_names(sort_from) tables = sorted( tables, key=lambda t: graph_distance(sort_from, t), @@ -498,13 +544,11 @@ class RestrGraph(AbstractGraph): def __init__( self, seed_table: Table, - table_name: str = None, - restriction: str = None, leaves: List[Dict[str, str]] = None, + destinations: List[str] = None, direction: Direction = "up", cascade: bool = False, verbose: bool = False, - ignore_peripheral: bool = False, **kwargs, ): """Use graph to 
cascade restrictions up from leaves to all ancestors. @@ -519,13 +563,12 @@ def __init__( ---------- seed_table : Table Table to use to establish connection and graph - table_name : str, optional - Table name of single leaf, default None - restriction : str, optional - Restriction to apply to leaf. default None leaves : Dict[str, str], optional List of dictionaries with keys table_name and restriction. One entry per leaf node. Default None. + destinations : List[str], optional + List of endpoints of interest in the graph. Default None. Used to + ignore nodes not in the path(s) to the destination(s). direction : Direction, optional Direction to cascade. Default 'up' cascade : bool, optional @@ -539,29 +582,23 @@ def __init__( """ super().__init__(seed_table, verbose=verbose) - self.add_leaf( - table_name=table_name, restriction=restriction, direction=direction - ) self.add_leaves(leaves) - if ignore_peripheral: - self._ignore_peripheral() - if cascade: - self.cascade(direction=direction) - - # --------------------------- Dunder Properties --------------------------- + if destinations: + if not isinstance(destinations, Iterable): + destinations = [destinations] + self._ignore_from_dest(self.leaves, destinations) - def __repr__(self): - l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else "" - processed = "Cascaded" if self.cascaded else "Uncascaded" - return f"{processed} {self.__class__.__name__}(\n\t{l_str})" + dir_list = ["up", "down"] if direction == "both" else [direction] - def __getitem__(self, index: Union[int, str]): - all_ft_names = [t.full_table_name for t in self.all_ft] - return fuzzy_get(index, all_ft_names, self.all_ft) - - def __len__(self): - return len(self.all_ft) + if cascade: + for dir in dir_list: + self._log_truncate(f"Start {dir:<4}: {self.leaves}") + self.cascade(direction=dir) + self.cascaded = False + self.visited -= self.leaves + self.cascaded = True + self.visited |= self.leaves # ---------------------------- Public Properties -------------------------- @@ -619,10 +656,12 @@ def _process_leaves(self, leaves=None, default_restriction=True): {"table_name": leaf, "restriction": default_restriction} for leaf in leaves ] - if all(isinstance(leaf, dict) for leaf in leaves) and not all( - leaf.get("table_name") for leaf in leaves - ): - raise ValueError(f"All leaves must have table_name: {leaves}") + if all(isinstance(leaf, dict) for leaf in leaves): + leaves = [ + {"table_name": k, "restriction": v} + for leaf in leaves + for k, v in leaf.items() + ] return unique_dicts(leaves) @@ -669,6 +708,7 @@ def cascade(self, show_progress=None, direction="up") -> None: Show tqdm progress bar. Default to verbose setting. """ if self.cascaded: + self._log_truncate("Already cascaded") return to_visit = self.leaves - self.visited @@ -680,7 +720,9 @@ def cascade(self, show_progress=None, direction="up") -> None: disable=not (show_progress or self.verbose), ): restr = self._get_restr(table) - self._log_truncate(f"Start {table}: {restr}") + self._log_truncate( + f"Start {direction:<4}: {self._camel(table)}, {restr}" + ) self.cascade1(table, restr, direction=direction) self.cascade_files() @@ -739,90 +781,6 @@ def file_paths(self) -> List[str]: ] -class TableChains: - """Class for representing chains from parent to Merge table via parts. - - Functions as a plural version of TableChain, allowing a single `cascade` - call across all chains from parent -> Merge table. - - Attributes - ---------- - parent : Table - Parent or origin of chains. 
- child : Table - Merge table or destination of chains. - connection : datajoint.Connection, optional - Connection to database used to create FreeTable objects. Defaults to - parent.connection. - part_names : List[str] - List of full table names of child parts. - chains : List[TableChain] - List of TableChain objects for each part in child. - has_link : bool - Cached attribute to store whether parent is linked to child via any of - child parts. False if (a) child is not in parent.descendants or (b) - nx.NetworkXNoPath is raised by nx.shortest_path for all chains. - - Methods - ------- - __init__(parent, child, connection=None) - Initialize TableChains with parent and child tables. - __repr__() - Return full representation of chains. - Multiline parent -> child for each chain. - __len__() - Return number of chains with links. - __getitem__(index: Union[int, str]) - Return TableChain object at index, or use substring of table name. - cascade(restriction: str = None) - Return list of cascade for each chain in self.chains. - """ - - def __init__(self, parent, child, direction=Direction.DOWN, verbose=False): - self.parent = parent - self.child = child - self.connection = parent.connection - self.part_names = child.parts() - self.chains = [ - TableChain(parent, part, direction=direction, verbose=verbose) - for part in self.part_names - ] - self.has_link = any([chain.has_link for chain in self.chains]) - - # --------------------------- Dunder Properties --------------------------- - - def __repr__(self): - l_str = ",\n\t".join([str(c) for c in self.chains]) + "\n" - return f"{self.__class__.__name__}(\n\t{l_str})" - - def __len__(self): - return len([c for c in self.chains if c.has_link]) - - def __getitem__(self, index: Union[int, str]): - """Return FreeTable object at index.""" - return fuzzy_get(index, self.part_names, self.chains) - - # ---------------------------- Public Properties -------------------------- - - @property - def max_len(self): - """Return length of longest chain.""" - return max([len(chain) for chain in self.chains]) - - # ------------------------------ Graph Traversal -------------------------- - - def cascade( - self, restriction: str = None, direction: Direction = Direction.DOWN - ): - """Return list of cascades for each chain in self.chains.""" - restriction = restriction or self.parent.restriction or True - cascades = [] - for chain in self.chains: - if joined := chain.cascade(restriction, direction): - cascades.append(joined) - return cascades - - class TableChain(RestrGraph): """Class for representing a chain of tables. @@ -872,8 +830,8 @@ def __init__( if not allow_merge and child is not None and is_merge_table(child): raise TypeError("Child is a merge table. 
Use TableChains instead.") - self.parent = self._ensure_name(parent) - self.child = self._ensure_name(child) + self.parent = self._ensure_names(parent) + self.child = self._ensure_names(child) if not self.parent and not self.child: raise ValueError("Parent or child table required.") @@ -884,7 +842,7 @@ def __init__( super().__init__(seed_table=seed_table, verbose=verbose) self._ignore_peripheral() - self.no_visit.update(self._ensure_name(banned_tables) or []) + self.no_visit.update(self._ensure_names(banned_tables) or []) self.no_visit.difference_update([self.parent, self.child]) self.searched_tables = set() self.found_restr = False @@ -1008,6 +966,15 @@ def cascade_search(self) -> None: + f"Restr: {restriction}" ) + def _and_parts(self, table): + """Return table, its master and parts.""" + ret = [table] + if master := get_master(table): + ret.append(master) + if parts := self._get_ft(table).parts(): + ret.extend(parts) + return ret + def _set_found_vars(self, table): """Set found_restr and searched_tables.""" self._set_restr(table, self.search_restr, replace=True) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 19227b567..6232d4b67 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -1,6 +1,5 @@ from atexit import register as exit_register from atexit import unregister as exit_unregister -from collections import OrderedDict from contextlib import nullcontext from functools import cached_property from inspect import stack as inspect_stack @@ -262,29 +261,27 @@ def _import_merge_tables(self): ) @cached_property - def _part_masters(self) -> Dict[str, dj.FreeTable]: - """Dict of part tables downstream of self: {camel_name: FreeTable}. + def _part_masters(self) -> set: + """Set of master tables downstream of self. - Cache of masters of self.descendants(as_objects=True). - Part must have other parent(s) besides master. + Cache of masters in self.descendants(as_objects=True) with another + foreign key reference in the part. Used for delete_downstream_parts. """ self.connection.dependencies.load() - part_masters = {} - visited = set() + part_masters = set() def search_descendants(parent): for desc in parent.descendants(as_objects=True): - if ( # Check if has master, no other fk, or already in cache - not (master_name := get_master(desc.full_table_name)) - or not set(desc.parents()) - set([master_name]) - or master_name in part_masters + if ( # Check if has master, is part + not (master := get_master(desc.full_table_name)) + # has other non-master parent + or not set(desc.parents()) - set([master]) + or master in part_masters # already in cache ): continue - master_ft = dj.FreeTable(self.connection, master_name) - part_masters[to_camel_case(master_name)] = master_ft - if master_name not in visited: - visited.add(master_name) - search_descendants(master_ft) + if master not in part_masters: + part_masters.add(master) + search_descendants(dj.FreeTable(self.connection, master)) try: _ = search_descendants(self) @@ -298,73 +295,20 @@ def search_descendants(parent): logger.info( f"Building part-parent cache for {self.camel_name}.\n\t" - + f"Found {len(part_masters)} downstream merge tables" + + f"Found {len(part_masters)} downstream part tables" ) return part_masters - @cached_property - def _merge_chains(self) -> OrderedDict[str, List[dj.FreeTable]]: - """Dict of chains to parts downstream of self - - Format: {camel_name: TableChains}. - - For each table found in _part_masters, find the path from self to - master via merge parts. 
If the path is valid, add it to the dict. Cache - prevents need to recompute whenever delete_downstream_part is called - with a new restriction. To recompute, add `reload_cache=True` to - delete_downstream_part call. - """ - from spyglass.utils.dj_graph import TableChains # noqa F401 - - merge_chains = {} - for name, merge_table in self._part_masters.items(): - chains = TableChains(self, merge_table) - if len(chains): - merge_chains[name] = chains - - # self->master chains ordered by max_len as a proxy for downstream-ness. - # A better approach would sort each self->part independently. - - return OrderedDict( - sorted( - merge_chains.items(), key=lambda x: x[1].max_len, reverse=True - ) - ) - - def _get_chain(self, substring): - """Return chain from self to merge table with substring in name.""" - for name, chain in self._merge_chains.items(): - if substring.lower() in name: - return chain - raise ValueError(f"No chain found with '{substring}' in name.") - - def _commit_part_deletes( - self, merge_join_dict: Dict[str, List[QueryExpression]], **kwargs - ) -> None: - """Commit merge deletes. - - Parameters - ---------- - merge_join_dict : Dict[str, List[QueryExpression]] - Dictionary of merge tables and their joins. Uses 'merge_id' primary - key to restrict delete. - - Extracted for use in cautious_delete and delete_downstream_parts.""" - for table_name, part_restr in merge_join_dict.items(): - table = self._part_masters[table_name] - keys = [part.proj().fetch(as_dict=True) for part in part_restr] - (table & keys).delete(**kwargs) - def delete_downstream_parts( self, restriction: str = None, dry_run: bool = True, reload_cache: bool = False, disable_warning: bool = False, - return_parts: bool = True, + return_graph: bool = False, **kwargs, - ) -> Union[List[QueryExpression], Dict[str, List[QueryExpression]]]: + ) -> List[dj.FreeTable]: """Delete downstream merge table entries associated with restriction. Requires caching of merge tables and links, which is slow on first call. @@ -381,24 +325,33 @@ def delete_downstream_parts( If True, reload merge cache. Default False. disable_warning : bool, optional If True, do not warn if no merge tables found. Default False. - return_parts : bool, optional - If True, return list of merge part entries to be deleted. Default + return_graph: bool, optional + If True, return RestrGraph object used to identify downstream + tables. Default False, return list of part FreeTables. True. If False, return dictionary of merge tables and their joins. **kwargs : Any Passed to datajoint.table.Table.delete. 
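+
+        Example
+        -------
+        A minimal sketch, mirroring the example in docs/src/misc/mixin.md:
+
+        >>> from spyglass.common import Nwbfile
+        >>> restr = "nwb_file_name LIKE 'Name%'"
+        >>> parts = (Nwbfile() & restr).delete_downstream_parts(dry_run=True)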
""" + from spyglass.utils.dj_graph import RestrGraph # noqa F401 + if reload_cache: - for attr in ["_merge_tables", "_merge_chains"]: - _ = self.__dict__.pop(attr, None) + _ = self.__dict__.pop("_part_masters", None) restriction = restriction or self.restriction or True - merge_join_dict = {} - for name, chain in self._merge_chains.items(): - if join := chain.cascade(restriction, direction="down"): - merge_join_dict[name] = join + restr_graph = RestrGraph( + seed_table=self, + leaves={self.full_table_name: restriction}, + direction="down", + cascade=True, + verbose=False, + ) + + master_fts = restr_graph.ft_from_list( + self._part_masters, sort_from=self.full_table_name + ) - if not merge_join_dict and not disable_warning: + if not master_fts and not disable_warning: logger.warning( f"No part deletes found w/ {self.camel_name} & " + f"{restriction}.\n\tIf this is unexpected, try importing " @@ -406,30 +359,16 @@ def delete_downstream_parts( ) if dry_run: - return merge_join_dict.values() if return_parts else merge_join_dict + return restr_graph if return_graph else master_fts - self._commit_part_deletes(merge_join_dict, **kwargs) + for master_ft in master_fts: + master_ft.delete(**kwargs) def ddp( - self, - restriction: str = None, - dry_run: bool = True, - reload_cache: bool = False, - disable_warning: bool = False, - return_parts: bool = True, - *args, - **kwargs, + self, *args, **kwargs ) -> Union[List[QueryExpression], Dict[str, List[QueryExpression]]]: """Alias for delete_downstream_parts.""" - return self.delete_downstream_parts( - restriction=restriction, - dry_run=dry_run, - reload_cache=reload_cache, - disable_warning=disable_warning, - return_parts=return_parts, - *args, - **kwargs, - ) + return self.delete_downstream_parts(*args, **kwargs) # ---------------------------- cautious_delete ---------------------------- @@ -554,16 +493,14 @@ def _check_delete_permission(self) -> None: logger.info(f"Queueing delete for session(s):\n{sess_summary}") @cached_property - def _usage_table(self): + def _cautious_del_tbl(self): """Temporary inclusion for usage tracking.""" from spyglass.common.common_usage import CautiousDelete return CautiousDelete() - def _log_delete(self, start, merge_deletes=None, super_delete=False): + def _log_delete(self, start, master_deletes=None, super_delete=False): """Log use of cautious_delete.""" - if isinstance(merge_deletes, QueryExpression): - merge_deletes = merge_deletes.fetch(as_dict=True) safe_insert = dict( duration=time() - start, dj_user=dj.config["database.user"], @@ -572,15 +509,15 @@ def _log_delete(self, start, merge_deletes=None, super_delete=False): restr_str = "Super delete: " if super_delete else "" restr_str += "".join(self.restriction) if self.restriction else "None" try: - self._usage_table.insert1( + self._cautious_del_tbl.insert1( dict( **safe_insert, restriction=restr_str[:255], - merge_deletes=merge_deletes, + merge_deletes=master_deletes, ) ) except (DataJointError, DataError): - self._usage_table.insert1( + self._cautious_del_tbl.insert1( dict(**safe_insert, restriction="Unknown") ) @@ -606,10 +543,9 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): if not force_permission: self._check_delete_permission() - merge_deletes = self.delete_downstream_parts( + master_fts = self.delete_downstream_parts( dry_run=True, disable_warning=True, - return_parts=False, ) safemode = ( @@ -617,17 +553,21 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): if kwargs.get("safemode") is None else 
kwargs["safemode"] ) + _ = kwargs.pop("safemode", None) - if merge_deletes: - for table, content in merge_deletes.items(): - count = sum([len(part) for part in content]) - dj_logger.info(f"Merge: Deleting {count} rows from {table}") + if master_fts: + for part in master_fts: + dj_logger.info( + f"Spyglass: Deleting {len(part)} rows from " + + f"{part.full_table_name}" + ) if ( - not self._test_mode + self._test_mode or not safemode or user_choice("Commit deletes?", default="no") == "yes" ): - self._commit_part_deletes(merge_deletes, **kwargs) + for master in master_fts: # safemode off b/c already checked + master.delete(safemode=False, **kwargs) else: logger.info("Delete aborted.") self._log_delete(start) @@ -635,7 +575,7 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): super().delete(*args, **kwargs) # Additional confirm here - self._log_delete(start=start, merge_deletes=merge_deletes) + self._log_delete(start=start, master_deletes=master_fts) def cdel(self, force_permission=False, *args, **kwargs): """Alias for cautious_delete.""" From 0b83423d1addbbef697cebf1772d0cf6c24ac3fc Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 6 Jun 2024 11:58:59 -0500 Subject: [PATCH 03/14] WIP: bridge up to interval list --- src/spyglass/common/common_interval.py | 7 + src/spyglass/common/common_usage.py | 2 +- src/spyglass/utils/dj_graph.py | 176 ++++++++++++++++++------- src/spyglass/utils/dj_mixin.py | 103 +++++++++++---- 4 files changed, 208 insertions(+), 80 deletions(-) diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index 66e82bda8..56b3a15b5 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -7,6 +7,7 @@ import pandas as pd from spyglass.utils import SpyglassMixin, logger +from spyglass.utils.dj_helper_fn import get_child_tables from .common_session import Session # noqa: F401 @@ -152,6 +153,12 @@ def plot_epoch_pos_raw_intervals(self, figsize=(20, 5), return_fig=False): if return_fig: return fig + def nightly_cleanup(self, dry_run=True): + orphans = self - get_child_tables(self) + if dry_run: + return orphans + orphans.super_delete() + def intervals_by_length(interval_list, min_length=0.0, max_length=1e10): """Select intervals of certain lengths from an interval list. diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index dae4f7842..31823795f 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -251,7 +251,7 @@ def make(self, key): # Writes but does not run mysqldump. Assumes single version per paper. 
version_key = query.fetch("spyglass_version", as_dict=True)[0] self.write_export( - free_tables=restr_graph.all_ft, **paper_key, **version_key + free_tables=restr_graph.restr_ft, **paper_key, **version_key ) self.insert1({**key, **paper_key}) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index e9d64e919..c23dbec80 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from collections.abc import KeysView +from copy import deepcopy from enum import Enum from functools import cached_property, partial from itertools import chain as iter_chain @@ -13,6 +14,7 @@ from datajoint import FreeTable, Table from datajoint.condition import make_condition from datajoint.dependencies import unite_master_parts +from datajoint.user_tables import TableMeta from datajoint.utils import get_master, to_camel_case from networkx import ( NetworkXNoPath, @@ -74,6 +76,7 @@ class AbstractGraph(ABC): Properties ---------- all_ft: Get all FreeTables for visited nodes with restrictions applied. + restr_ft: Get non-empty FreeTables for visited nodes with restrictions. as_dict: Get visited nodes as a list of dictionaries of {table_name: restriction} """ @@ -91,8 +94,14 @@ def __init__(self, seed_table: Table, verbose: bool = False, **kwargs): self.seed_table = seed_table self.connection = seed_table.connection - self.graph = seed_table.connection.dependencies - self.graph.load() + # Deepcopy graph to avoid seed `load()` resetting custom attributes + seed_table.connection.dependencies.load() + graph = seed_table.connection.dependencies + orig_conn = graph._conn # Cannot deepcopy connection + graph._conn = None + self.graph = deepcopy(graph) + graph._conn = orig_conn + # undirect not needed in all cases but need to do before adding ft nodes self.undirect_graph = self.graph.to_undirected() @@ -114,7 +123,7 @@ def cascade(self): def __repr__(self): l_str = ( - ",\n\t".join(self.leaves) + "\n" + ",\n\t".join(self._camel(self.leaves)) + "\n" if self.leaves else self._camel(self.seed_table) ) @@ -126,7 +135,7 @@ def __getitem__(self, index: Union[int, str]): return fuzzy_get(index, all_ft_names, self.all_ft) def __len__(self): - return len(self.all_ft) + return len(self.restr_ft) # ---------------------------- Logging Helpers ---------------------------- @@ -144,6 +153,7 @@ def _camel(self, table): table = list(table) if not isinstance(table, list): table = [table] + table = self._ensure_names(table) ret = [to_camel_case(t.split(".")[-1].strip("`")) for t in table] return ret[0] if len(ret) == 1 else ret @@ -155,13 +165,17 @@ def _print_restr(self): # ------------------------------ Graph Nodes ------------------------------ - def _ensure_names(self, table: Union[str, Table] = None) -> str: + def _ensure_names( + self, table: Union[str, Table] = None + ) -> Union[str, List[str]]: """Ensure table is a string.""" if table is None: return None if isinstance(table, str): return table - if isinstance(table, Iterable): + if isinstance(table, Iterable) and not isinstance( + table, (Table, TableMeta) + ): return [self._ensure_names(t) for t in table] return getattr(table, "full_table_name", None) @@ -177,6 +191,7 @@ def _get_node(self, table: Union[str, Table]): def _set_node(self, table, attr: str = "ft", value: Any = None): """Set attribute on node. 
General helper for various attributes.""" + table = self._ensure_names(table) _ = self._get_node(table) # Ensure node exists self.graph.nodes[table][attr] = value @@ -224,6 +239,7 @@ def _set_restr(self, table, restriction, replace=False): else restriction ) existing = self._get_restr(table) + if not replace and existing: if restriction == existing: return @@ -236,12 +252,13 @@ def _set_restr(self, table, restriction, replace=False): self._set_node(table, "restr", restriction) - def _get_ft(self, table, with_restr=False): + def _get_ft(self, table, with_restr=False, warn=True): """Get FreeTable from graph node. If one doesn't exist, create it.""" table = self._ensure_names(table) if with_restr: if not (restr := self._get_restr(table) or False): - self._log_truncate(f"No restriction for {table}") + if warn: + self._log_truncate(f"No restr for {self._camel(table)}") else: restr = True @@ -478,15 +495,21 @@ def all_ft(self): Topological sort logic adopted from datajoint.diagram. """ - self.cascade() + self.cascade(warn=False) nodes = [n for n in self.visited if not n.isnumeric()] sorted_nodes = unite_master_parts( list(topological_sort(self.graph.subgraph(nodes))) ) - all_ft = [ - self._get_ft(table, with_restr=True) for table in sorted_nodes + ret = [ + self._get_ft(table, with_restr=True, warn=False) + for table in sorted_nodes ] - return [ft for ft in all_ft if len(ft) > 0] + return ret + + @property + def restr_ft(self): + """Get non-empty restricted FreeTables from all visited nodes.""" + return [ft for ft in self.all_ft if len(ft) > 0] def ft_from_list( self, @@ -515,7 +538,7 @@ def graph_distance(self, table1: str = None, table2: str = None) -> int: except (NodeNotFound, NetworkXNoPath): return 99 - self.cascade() + self.cascade(warn=False) tables = [self._ensure_names(t) for t in tables] if sort_from: @@ -525,7 +548,10 @@ def graph_distance(self, table1: str = None, table2: str = None) -> int: key=lambda t: graph_distance(sort_from, t), ) - fts = [self._get_ft(table, with_restr=with_restr) for table in tables] + fts = [ + self._get_ft(table, with_restr=with_restr, warn=False) + for table in tables + ] return [ft for ft in fts if len(ft) > 0] @@ -593,7 +619,7 @@ def __init__( if cascade: for dir in dir_list: - self._log_truncate(f"Start {dir:<4}: {self.leaves}") + self._log_truncate(f"Start {dir:<4} : {self.leaves}") self.cascade(direction=dir) self.cascaded = False self.visited -= self.leaves @@ -656,12 +682,19 @@ def _process_leaves(self, leaves=None, default_restriction=True): {"table_name": leaf, "restriction": default_restriction} for leaf in leaves ] + hashable = True if all(isinstance(leaf, dict) for leaf in leaves): - leaves = [ - {"table_name": k, "restriction": v} - for leaf in leaves - for k, v in leaf.items() - ] + new_leaves = [] + for leaf in leaves: + for table, restr in leaf.items(): + if not isinstance(restr, (str, dict)): + hashable = False # likely a dj.AndList + new_leaves.append( + {"table_name": table, "restriction": restr} + ) + if not hashable: + return new_leaves + leaves = new_leaves return unique_dicts(leaves) @@ -699,7 +732,7 @@ def add_leaves( # ------------------------------ Graph Traversal -------------------------- - def cascade(self, show_progress=None, direction="up") -> None: + def cascade(self, show_progress=None, direction="up", warn=True) -> None: """Cascade all restrictions up the graph. Parameters @@ -708,7 +741,8 @@ def cascade(self, show_progress=None, direction="up") -> None: Show tqdm progress bar. Default to verbose setting. 
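+        direction : Direction, optional
+            Direction to cascade. Default 'up'.
+        warn : bool, optional
+            If True, log a note when the graph is already cascaded and the
+            call returns early. Default True.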
""" if self.cascaded: - self._log_truncate("Already cascaded") + if warn: + self._log_truncate("Already cascaded") return to_visit = self.leaves - self.visited @@ -725,24 +759,11 @@ def cascade(self, show_progress=None, direction="up") -> None: ) self.cascade1(table, restr, direction=direction) - self.cascade_files() - self.cascaded = True + self.cascaded = True # Mark here so next step can use `restr_ft` + self.cascade_files() # Otherwise attempts to re-cascade, recursively # ----------------------------- File Handling ----------------------------- - def _get_files(self, table): - """Get analysis files from graph node.""" - return self._get_node(table).get("files", []) - - def cascade_files(self): - """Set node attribute for analysis files.""" - for table in self.visited: - ft = self._get_ft(table, with_restr=True) - if not set(self.analysis_pk).issubset(ft.heading.names): - continue - files = list(ft.fetch(*self.analysis_pk)) - self._set_node(table, "files", files) - @property def analysis_file_tbl(self) -> Table: """Return the analysis file table. Avoids circular import.""" @@ -750,10 +771,14 @@ def analysis_file_tbl(self) -> Table: return AnalysisNwbfile() - @property - def analysis_pk(self) -> List[str]: - """Return primary key fields from analysis file table.""" - return self.analysis_file_tbl.primary_key + def cascade_files(self): + """Set node attribute for analysis files.""" + analysis_pk = self.analysis_file_tbl.primary_key + for ft in self.restr_ft: + if not set(analysis_pk).issubset(ft.heading.names): + continue + files = list(ft.fetch(*analysis_pk)) + self._set_node(ft, "files", files) @property def file_dict(self) -> Dict[str, List[str]]: @@ -761,8 +786,8 @@ def file_dict(self) -> Dict[str, List[str]]: Included for debugging, to associate files with tables. """ - self.cascade() - return {t: self._get_node(t).get("files", []) for t in self.visited} + self.cascade(warn=False) + return {t: self._get_node(t).get("files", []) for t in self.restr_ft} @property def file_paths(self) -> List[str]: @@ -780,13 +805,66 @@ def file_paths(self) -> List[str]: if file is not None ] + # ---------------------------- Orphan handling ---------------------------- + + @property + def interval_tbl_name(self) -> Table: + """Return the interval list table name. Avoids circular import.""" + from spyglass.common import IntervalList + + return IntervalList.full_table_name + + def find_orphans( + self, part_masters: Union[List[str], List[FreeTable]] + ) -> Tuple[List[FreeTable], List[FreeTable]]: + """ + Find would-be orphaned tables in IntervalList and downstream parts. + + Parameters + ---------- + part_masters : List + List of part master tables to check for orphans. + """ + self.cascade(warn=False) + self._log_truncate("Orphan Search") + + for ft in self.restr_ft: + if self.interval_tbl_name not in ft.parents(): + continue + + _, edge = self._get_edge(ft, self.interval_tbl_name) + interval_restr = self._bridge_restr( + table1=ft.full_table_name, + table2=self.interval_tbl_name, + restr=ft.restriction, + direction=Direction.UP, + **edge, + ) + + self._set_restr( + table=self.interval_tbl_name, + restriction=interval_restr, + ) + + interval_ft = self._get_ft(self.interval_tbl_name, with_restr=True) + self._log_truncate(f"IntervalList entries {len(interval_ft)}") + + upstream = [interval_ft] # As list for extensibility + downstream = self.ft_from_list(part_masters, sort_from=self.seed_table) + + return upstream, downstream + class TableChain(RestrGraph): """Class for representing a chain of tables. 
A chain is a sequence of tables from parent to child identified by - networkx.shortest_path. Parent -> Merge should use TableChains instead to - handle multiple paths to the respective parts of the Merge table. + networkx.shortest_path from parent to child. To avoid issues with merge + tables, use the Merge table as the child, not the part table. + + Either the parent or child can be omitted if a search_restr is provided. + The missing table will be found by searching for where the restriction + can be applied. Attributes ---------- @@ -798,9 +876,6 @@ class TableChain(RestrGraph): Cached attribute to store whether parent is linked to child. path : List[str] Names of tables along the path from parent to child. - all_ft : List[dj.FreeTable] - List of FreeTable objects for each table in chain with restriction - applied. Methods ------- @@ -813,6 +888,8 @@ class TableChain(RestrGraph): Given a restriction at the beginning, return a restricted FreeTable object at the end of the chain. If direction is 'up', start at the child and move up to the parent. If direction is 'down', start at the parent. + cascade_search() + Search from the leaf node to find where a restriction can be applied. """ def __init__( @@ -843,7 +920,7 @@ def __init__( self._ignore_peripheral() self.no_visit.update(self._ensure_names(banned_tables) or []) - self.no_visit.difference_update([self.parent, self.child]) + self.no_visit.difference_update(set([self.parent, self.child])) self.searched_tables = set() self.found_restr = False self.link_type = None @@ -1062,7 +1139,6 @@ def find_path(self, directed=True) -> List[str]: search_graph = self.graph if not directed: - # FTs in self.graph prevent `to_undirected` from working. self.connection.dependencies.load() search_graph = self.undirect_graph diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 6232d4b67..464d5c6a1 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -347,11 +347,12 @@ def delete_downstream_parts( verbose=False, ) - master_fts = restr_graph.ft_from_list( - self._part_masters, sort_from=self.full_table_name - ) + if return_graph: + return restr_graph - if not master_fts and not disable_warning: + _, down_fts = restr_graph.find_orphans(self._part_masters) + + if not down_fts and not disable_warning: logger.warning( f"No part deletes found w/ {self.camel_name} & " + f"{restriction}.\n\tIf this is unexpected, try importing " @@ -359,9 +360,9 @@ def delete_downstream_parts( ) if dry_run: - return restr_graph if return_graph else master_fts + return down_fts - for master_ft in master_fts: + for master_ft in down_fts: master_ft.delete(**kwargs) def ddp( @@ -374,18 +375,27 @@ def ddp( @cached_property def _delete_deps(self) -> List[Table]: - """List of tables required for delete permission check. + """List of tables required for delete permission and orphan checks. LabMember, LabTeam, and Session are required for delete permission. + common_nwbfile.schema.external is required for deleting orphaned + external files. IntervalList is required for deleting orphaned interval + lists. Used to delay import of tables until needed, avoiding circular imports. - Each of these tables inheits SpyglassMixin. + Each of these tables inherits SpyglassMixin. 
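# Example (sketch): preview downstream part deletes without running them.
# Assumes the relevant pipeline tables are imported; the restriction is a
# placeholder file-name pattern.
from spyglass.common import Nwbfile

restricted = Nwbfile() & "nwb_file_name LIKE 'example%'"
fts = restricted.delete_downstream_parts(dry_run=True)  # restricted part-master FreeTables
graph = restricted.delete_downstream_parts(  # or inspect the underlying RestrGraph
    dry_run=True, return_graph=True
)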
""" - from spyglass.common import LabMember, LabTeam, Session # noqa F401 + from spyglass.common import ( # noqa F401 + IntervalList, + LabMember, + LabTeam, + Session, + ) + from spyglass.common.common_nwbfile import schema # noqa F401 self._session_pk = Session.primary_key[0] self._member_pk = LabMember.primary_key[0] - return [LabMember, LabTeam, Session] + return [LabMember, LabTeam, Session, schema.external, IntervalList] def _get_exp_summary(self): """Get summary of experimenters for session(s), including NULL. @@ -401,7 +411,7 @@ def _get_exp_summary(self): Summary of experimenters for session(s). """ - Session = self._delete_deps[-1] + Session = self._delete_deps[2] SesExp = Session.Experimenter # Not called in delete permission check, only bare _get_exp_summary @@ -424,7 +434,7 @@ def _session_connection(self): """Path from Session table to self. False if no connection found.""" from spyglass.utils.dj_graph import TableChain # noqa F401 - connection = TableChain(parent=self._delete_deps[-1], child=self) + connection = TableChain(parent=self._delete_deps[2], child=self) return connection if connection.has_link else False @cached_property @@ -450,7 +460,7 @@ def _check_delete_permission(self) -> None: Permission denied because (a) Session has no experimenter, or (b) user is not on a team with Session experimenter(s). """ - LabMember, LabTeam, Session = self._delete_deps + LabMember, LabTeam, Session, _ = self._delete_deps dj_user = dj.config["database.user"] if dj_user in LabMember().admin: # bypass permission check for admin @@ -522,8 +532,10 @@ def _log_delete(self, start, master_deletes=None, super_delete=False): ) # TODO: Intercept datajoint delete confirmation prompt for merge deletes - def cautious_delete(self, force_permission: bool = False, *args, **kwargs): - """Delete table rows after checking user permission. + def cautious_delete( + self, force_permission: bool = False, dry_run=False, *args, **kwargs + ): + """Permission check, then delete potential orphans and table rows. Permission is granted to users listed as admin in LabMember table or to users on a team with with the Session experimenter(s). If the table @@ -531,22 +543,41 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): continues. If the Session has no experimenter, or if the user is not on a team with the Session experimenter(s), a PermissionError is raised. + Potential downstream orphans are deleted first. These are master tables + whose parts have foreign keys so descendants of self. Then, rows from + self are deleted. Last, IntervalList and Nwbfile externals are deleted. + Parameters ---------- force_permission : bool, optional Bypass permission check. Default False. + dry_run : bool, optional + Default False. If True, return items to be deleted as + Tuple[Upstream, Downstream, externals['raw'], externals['analysis']] + If False, delete items. *args, **kwargs : Any Passed to datajoint.table.Table.delete. 
""" start = time() + external = self._delete_deps[3] - if not force_permission: + if not force_permission or dry_run: self._check_delete_permission() - master_fts = self.delete_downstream_parts( + restr_graph = self.delete_downstream_parts( dry_run=True, disable_warning=True, + return_graph=True, ) + up_fts, down_fts = restr_graph.find_orphans(self._part_masters) + + if dry_run: + return ( + up_fts, + down_fts, + external["raw"].unused(), + external["analysis"].unused(), + ) safemode = ( dj.config.get("safemode", True) @@ -555,19 +586,24 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): ) _ = kwargs.pop("safemode", None) - if master_fts: - for part in master_fts: + if up_fts or down_fts: + for down_ft in down_fts: + dj_logger.info( + f"Spyglass: Deleting {len(down_ft)} rows from " + + f"{down_ft.full_table_name}" + ) + for up_ft in up_fts: dj_logger.info( - f"Spyglass: Deleting {len(part)} rows from " - + f"{part.full_table_name}" + f"Spyglass: Deleting {len(up_ft)} rows from " + + f"{up_ft.full_table_name} after next prompt" ) if ( self._test_mode or not safemode or user_choice("Commit deletes?", default="no") == "yes" ): - for master in master_fts: # safemode off b/c already checked - master.delete(safemode=False, **kwargs) + for down_ft in down_fts: # safemode off b/c already checked + down_ft.delete(safemode=False, **kwargs) else: logger.info("Delete aborted.") self._log_delete(start) @@ -575,15 +611,24 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): super().delete(*args, **kwargs) # Additional confirm here - self._log_delete(start=start, master_deletes=master_fts) + if up_fts: + for up_ft in up_fts: + up_ft.delete(safemode=False, **kwargs) - def cdel(self, force_permission=False, *args, **kwargs): + for ext_type in ["raw", "analysis"]: + external[ext_type].delete( + delete_external=True, display_progress=True + ) + + self._log_delete(start=start, master_deletes=up_fts + down_fts) + + def cdel(self, *args, **kwargs): """Alias for cautious_delete.""" - self.cautious_delete(force_permission=force_permission, *args, **kwargs) + return self.cautious_delete(*args, **kwargs) - def delete(self, force_permission=False, *args, **kwargs): + def delete(self, *args, **kwargs): """Alias for cautious_delete, overwrites datajoint.table.Table.delete""" - self.cautious_delete(force_permission=force_permission, *args, **kwargs) + self.cautious_delete(*args, **kwargs) def super_delete(self, warn=True, *args, **kwargs): """Alias for datajoint.table.Table.delete.""" @@ -596,7 +641,7 @@ def super_delete(self, warn=True, *args, **kwargs): @cached_property def _spyglass_version(self): - """Get Spyglass version from dj.config.""" + """Get Spyglass version.""" from spyglass import __version__ as sg_version return ".".join(sg_version.split(".")[:3]) # Major.Minor.Patch From 40f46c649cbe77c4982b2b3d16d011da4ac56ca5 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 7 Jun 2024 19:17:14 -0500 Subject: [PATCH 04/14] Add tests for new delete --- .gitignore | 1 + pyproject.toml | 2 +- src/spyglass/common/common_usage.py | 18 +-- src/spyglass/utils/dj_graph.py | 179 +++++++++------------------- src/spyglass/utils/dj_mixin.py | 117 ++++++++++-------- tests/common/test_usage.py | 89 ++++++++++++++ tests/conftest.py | 32 ++++- tests/utils/conftest.py | 17 +-- tests/utils/test_chains.py | 31 ++--- tests/utils/test_graph.py | 138 +++++++++++++++++++-- tests/utils/test_merge.py | 19 +++ tests/utils/test_mixin.py | 91 +++++++++----- 12 files changed, 470 insertions(+), 
264 deletions(-) create mode 100644 tests/common/test_usage.py diff --git a/.gitignore b/.gitignore index 052080023..2d912b111 100644 --- a/.gitignore +++ b/.gitignore @@ -127,6 +127,7 @@ dmypy.json .pyre/ # Test Data Files +tests/_data/* *.dat *.mda *.rec diff --git a/pyproject.toml b/pyproject.toml index ffb8d0df6..9c74c8c0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,7 @@ ignore-words-list = 'nevers' [tool.pytest.ini_options] minversion = "7.0" addopts = [ - "-sv", + # "-sv", # "--sw", # stepwise: resume with next test after failure # "--pdb", # drop into debugger on failure "-p no:warnings", diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 31823795f..5dca00185 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -14,7 +14,7 @@ from datajoint import config as dj_config from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile -from spyglass.settings import export_dir +from spyglass.settings import export_dir, test_mode from spyglass.utils import SpyglassMixin, logger from spyglass.utils.dj_graph import RestrGraph from spyglass.utils.dj_helper_fn import unique_dicts @@ -98,6 +98,8 @@ def insert1_return_pk(self, key: dict, **kwargs) -> int: export_id = query.fetch1("export_id") export_key = {"export_id": export_id} if query := (Export & export_key): + if test_mode: + query.super_delete(warn=False, safemode=False) query.super_delete(warn=False) logger.info(f"{status} {export_key}") return export_id @@ -169,9 +171,11 @@ def _max_export_id(self, paper_id: str, return_all=False) -> int: all_export_ids = query.fetch("export_id") return all_export_ids if return_all else max(all_export_ids) - def paper_export_id(self, paper_id: str) -> dict: + def paper_export_id(self, paper_id: str, return_all=False) -> dict: """Return the maximum export_id for a paper, used to populate Export.""" - return {"export_id": self._max_export_id(paper_id)} + if not return_all: + return {"export_id": self._max_export_id(paper_id)} + return [{"export_id": id} for id in self._max_export_id(paper_id, True)] @schema @@ -210,11 +214,11 @@ def populate_paper(self, paper_id: Union[str, dict]): self.populate(ExportSelection().paper_export_id(paper_id)) def make(self, key): - query = ExportSelection & key - paper_key = query.fetch("paper_id", as_dict=True)[0] + paper_key = (ExportSelection & key).fetch("paper_id", as_dict=True)[0] + query = ExportSelection & paper_key # Null insertion if export_id is not the maximum for the paper - all_export_ids = query._max_export_id(paper_key, return_all=True) + all_export_ids = ExportSelection()._max_export_id(paper_key, True) max_export_id = max(all_export_ids) if key.get("export_id") != max_export_id: logger.info( @@ -235,7 +239,7 @@ def make(self, key): (self.Table & id_dict).delete_quick() (self.Table & id_dict).delete_quick() - restr_graph = query.get_restr_graph(paper_key) + restr_graph = ExportSelection().get_restr_graph(paper_key) file_paths = unique_dicts( # Original plus upstream files query.list_file_paths(paper_key) + restr_graph.file_paths ) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index c23dbec80..74f9e98b1 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -4,10 +4,9 @@ """ from abc import ABC, abstractmethod -from collections.abc import KeysView from copy import deepcopy from enum import Enum -from functools import cached_property, partial +from functools import cached_property from itertools 
import chain as iter_chain from typing import Any, Dict, Iterable, List, Set, Tuple, Union @@ -125,14 +124,14 @@ def __repr__(self): l_str = ( ",\n\t".join(self._camel(self.leaves)) + "\n" if self.leaves - else self._camel(self.seed_table) + else "Seed: " + self._camel(self.seed_table) + "\n" ) casc_str = "Cascaded" if self.cascaded else "Uncascaded" return f"{casc_str} {self.__class__.__name__}(\n\t{l_str})" def __getitem__(self, index: Union[int, str]): - all_ft_names = [t.full_table_name for t in self.all_ft] - return fuzzy_get(index, all_ft_names, self.all_ft) + names = [t.full_table_name for t in self.restr_ft] + return fuzzy_get(index, names, self.restr_ft) def __len__(self): return len(self.restr_ft) @@ -149,19 +148,13 @@ def _log_truncate(self, log_str: str, max_len: int = 80): def _camel(self, table): """Convert table name(s) to camel case.""" - if isinstance(table, KeysView): - table = list(table) - if not isinstance(table, list): - table = [table] table = self._ensure_names(table) - ret = [to_camel_case(t.split(".")[-1].strip("`")) for t in table] - return ret[0] if len(ret) == 1 else ret - - def _print_restr(self): - """Print restrictions for debugging.""" - for table in self.visited: - if restr := self._get_restr(table): - logger.info(f"{table}: {restr}") + if isinstance(table, str): + return to_camel_case(table.split(".")[-1].strip("`")) + if isinstance(table, Iterable) and not isinstance( + table, (Table, TableMeta) + ): + return [self._camel(t) for t in table] # ------------------------------ Graph Nodes ------------------------------ @@ -270,37 +263,12 @@ def _get_ft(self, table, with_restr=False, warn=True): # ------------------------------ Ignore Nodes ------------------------------ - def _ignore_peripheral(self): + def _ignore_peripheral(self, except_tables: List[str] = None): """Ignore peripheral tables in graph traversal.""" - self.no_visit.update(PERIPHERAL_TABLES) - self.undirect_graph.remove_nodes_from(PERIPHERAL_TABLES) - - def _ignore_from_dest(self, sources: List[str], dest: List[str]): - """Ignore nodes from destination(s) in graph traversal.""" - path_nodes = set() - - for source in self._ensure_names(sources): - for table in self._ensure_names(dest): - partial_path = partial( - shortest_path, source=source, target=table - ) - try: - path = partial_path(self.graph) - dir = "Directed" - except (NodeNotFound, NetworkXNoPath): - try: # Try undirected graph - path = partial_path(self.undirect_graph) - dir = "Undirect" - except (NodeNotFound, NetworkXNoPath): - path = ["NONE"] - dir = "NoPath " - self._log_truncate( - f"{dir}: {self._camel(source)} -> {self._camel(table)}: {path}" - ) - path_nodes.update(path) - - unused_nodes = set(self.graph.nodes) - path_nodes - self.no_visit.update(unused_nodes) + except_tables = self._ensure_names(except_tables) + ignore_tables = set(PERIPHERAL_TABLES) - set(except_tables or []) + self.no_visit.update(ignore_tables) + self.undirect_graph.remove_nodes_from(ignore_tables) # ---------------------------- Graph Traversal ----------------------------- @@ -516,6 +484,7 @@ def ft_from_list( tables: List[str], with_restr: bool = True, sort_from: str = None, + return_empty: bool = False, ) -> List[FreeTable]: """Return non-empty FreeTable objects from list of table names. @@ -526,7 +495,8 @@ def ft_from_list( with_restr : bool, optional Restrict FreeTable to restriction. Default True. sort_from : str, optional - Table name. Sort by distance from this table. Default None, no sort. + Table name. 
Sort by decreasing distance from this table. + Default None, no sort. """ def graph_distance(self, table1: str = None, table2: str = None) -> int: @@ -546,6 +516,7 @@ def graph_distance(self, table1: str = None, table2: str = None) -> int: tables = sorted( tables, key=lambda t: graph_distance(sort_from, t), + reverse=True, # sort from farthest to closest ) fts = [ @@ -553,7 +524,7 @@ def graph_distance(self, table1: str = None, table2: str = None) -> int: for table in tables ] - return [ft for ft in fts if len(ft) > 0] + return fts if return_empty else [ft for ft in fts if len(ft) > 0] @property def as_dict(self) -> List[Dict[str, str]]: @@ -610,11 +581,6 @@ def __init__( self.add_leaves(leaves) - if destinations: - if not isinstance(destinations, Iterable): - destinations = [destinations] - self._ignore_from_dest(self.leaves, destinations) - dir_list = ["up", "down"] if direction == "both" else [direction] if cascade: @@ -672,7 +638,13 @@ def add_leaf( self.cascaded = True def _process_leaves(self, leaves=None, default_restriction=True): - """Process leaves to ensure they are unique and have required keys.""" + """Process leaves to ensure they are unique and have required keys. + + Accepts ... + - [str]: table names, use default_restriction + - [{'table_name': str, 'restriction': str}]: used for export + - [{table_name: restriction}]: userd for distance restriction + """ if not leaves: return [] if not isinstance(leaves, list): @@ -686,6 +658,9 @@ def _process_leaves(self, leaves=None, default_restriction=True): if all(isinstance(leaf, dict) for leaf in leaves): new_leaves = [] for leaf in leaves: + if "table_name" in leaf and "restriction" in leaf: + new_leaves.append(leaf) + continue for table, restr in leaf.items(): if not isinstance(restr, (str, dict)): hashable = False # likely a dj.AndList @@ -805,55 +780,6 @@ def file_paths(self) -> List[str]: if file is not None ] - # ---------------------------- Orphan handling ---------------------------- - - @property - def interval_tbl_name(self) -> Table: - """Return the interval list table name. Avoids circular import.""" - from spyglass.common import IntervalList - - return IntervalList.full_table_name - - def find_orphans( - self, part_masters: Union[List[str], List[FreeTable]] - ) -> Tuple[List[FreeTable], List[FreeTable]]: - """ - Find would-be orphaned tables in IntervalList and downstream parts. - - Parameters - ---------- - part_masters : List - List of part master tables to check for orphans. - """ - self.cascade(warn=False) - self._log_truncate("Orphan Search") - - for ft in self.restr_ft: - if self.interval_tbl_name not in ft.parents(): - continue - - _, edge = self._get_edge(ft, self.interval_tbl_name) - interval_restr = self._bridge_restr( - table1=ft.full_table_name, - table2=self.interval_tbl_name, - restr=ft.restriction, - direction=Direction.UP, - **edge, - ) - - self._set_restr( - table=self.interval_tbl_name, - restriction=interval_restr, - ) - - interval_ft = self._get_ft(self.interval_tbl_name, with_restr=True) - self._log_truncate(f"IntervalList entries {len(interval_ft)}") - - upstream = [interval_ft] # As list for extensibility - downstream = self.ft_from_list(part_masters, sort_from=self.seed_table) - - return upstream, downstream - class TableChain(RestrGraph): """Class for representing a chain of tables. 
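# Example (sketch): the three leaf formats `_process_leaves` accepts, per the
# docstring above. Table names and restrictions are placeholders.
leaves_as_names = ["`common_session`.`session`"]  # default restriction applied
leaves_for_export = [
    {
        "table_name": "`common_session`.`session`",
        "restriction": "nwb_file_name LIKE 'example%'",
    }
]
leaves_by_restriction = [
    {"`common_session`.`session`": "nwb_file_name LIKE 'example%'"}  # {table: restriction}
]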
@@ -900,25 +826,19 @@ def __init__( search_restr: str = None, cascade: bool = False, verbose: bool = False, - allow_merge: bool = False, banned_tables: List[str] = None, **kwargs, ): - if not allow_merge and child is not None and is_merge_table(child): - raise TypeError("Child is a merge table. Use TableChains instead.") - self.parent = self._ensure_names(parent) self.child = self._ensure_names(child) if not self.parent and not self.child: raise ValueError("Parent or child table required.") - if not search_restr and not (self.parent and self.child): - raise ValueError("Search restriction required to find path.") seed_table = parent if isinstance(parent, Table) else child super().__init__(seed_table=seed_table, verbose=verbose) - self._ignore_peripheral() + self._ignore_peripheral(except_tables=[self.parent, self.child]) self.no_visit.update(self._ensure_names(banned_tables) or []) self.no_visit.difference_update(set([self.parent, self.child])) self.searched_tables = set() @@ -929,6 +849,8 @@ def __init__( self.search_restr = search_restr self.direction = Direction(direction) + if self.parent and self.child and not self.direction: + self.direction = Direction.DOWN self.leaf = None if search_restr and not parent: @@ -942,8 +864,9 @@ def __init__( self.add_leaf(self.leaf, True, cascade=False, direction=direction) if cascade and search_restr: - self.cascade_search() - self.cascade(restriction=search_restr) + self.cascade_search() # only cascade if found or not looking + if (search_restr and self.found_restr) or not search_restr: + self.cascade(restriction=search_restr) self.cascaded = True # --------------------------- Dunder Properties --------------------------- @@ -970,9 +893,6 @@ def __len__(self): return 0 return len(self.path) - def __getitem__(self, index: Union[int, str]): - return fuzzy_get(index, self.path, self.all_ft) - # ---------------------------- Public Properties -------------------------- @property @@ -992,6 +912,12 @@ def path_str(self) -> str: return "No link" return self._link_symbol.join([self._camel(t) for t in self.path]) + @property + def path_ft(self) -> List[FreeTable]: + """Return FreeTables along the path.""" + path_with_ends = set([self.parent, self.child]) | set(self.path) + return self.ft_from_list(path_with_ends, with_restr=True) + # ------------------------------ Graph Nodes ------------------------------ def _set_find_restr(self, table_name, restriction): @@ -1034,6 +960,7 @@ def cascade_search(self) -> None: replace=True, ) if not self.found_restr: + self.link_type = None searched = ( "parents" if self.direction == Direction.UP else "children" ) @@ -1136,11 +1063,7 @@ def find_path(self, directed=True) -> List[str]: List of names in the path. 
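# Example (sketch): per the updated docstring, the child can be omitted when a
# `search_restr` is given; `cascade_search` then looks downstream of the parent
# for a table where the restriction applies. The restriction is a placeholder.
from spyglass.common import Session
from spyglass.utils.dj_graph import TableChain

chain = TableChain(
    parent=Session(),
    search_restr="interval_list_name LIKE 'pos%'",
    direction="down",
    cascade=True,
)
print(chain.path_str)  # camel-case names along the found path, or "No link"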
""" source, target = self.parent, self.child - search_graph = self.graph - - if not directed: - self.connection.dependencies.load() - search_graph = self.undirect_graph + search_graph = self.graph if directed else self.undirect_graph search_graph.remove_nodes_from(self.no_visit) @@ -1157,7 +1080,6 @@ def find_path(self, directed=True) -> List[str]: ignore_nodes = self.graph.nodes - set(path) self.no_visit.update(ignore_nodes) - self._log_truncate(f"Ignore : {ignore_nodes}") return path @cached_property @@ -1175,7 +1097,9 @@ def path(self) -> list: return path - def cascade(self, restriction: str = None, direction: Direction = None): + def cascade( + self, restriction: str = None, direction: Direction = None, **kwargs + ): if not self.has_link: return @@ -1191,11 +1115,18 @@ def cascade(self, restriction: str = None, direction: Direction = None): self.cascade1( table=start, - restriction=restriction or self._get_restr(start), + restriction=restriction or self._get_restr(start) or True, direction=direction, replace=True, ) + # Cascade will stop if any restriction is empty, so set rest to None + non_numeric = [t for t in self.path if not t.isnumeric()] + if any(self._get_restr(t) is None for t in non_numeric): + for table in non_numeric: + if table is not start: + self._set_restr(table, False, replace=True) + return self._get_ft(end, with_restr=True) def restrict_by(self, *args, **kwargs) -> None: diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 464d5c6a1..bad393065 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -110,14 +110,14 @@ def _auto_increment(self, key, pk, *args, **kwargs): def file_like(self, name=None, **kwargs): """Convenience method for wildcard search on file name fields.""" if not name: - return self & True + return self attr = None for field in self.heading.names: if "file" in field: attr = field break if not attr: - logger.error(f"No file-like field found in {self.full_table_name}") + logger.error(f"No file_like field found in {self.full_table_name}") return return self & f"{attr} LIKE '%{name}%'" @@ -300,6 +300,43 @@ def search_descendants(parent): return part_masters + def _commit_downstream_delete(self, down_fts, start=None, **kwargs): + """ + Commit delete of downstream parts via down_fts. Logs with _log_delete. + + Used by both delete_downstream_parts and cautious_delete. 
+ """ + start = start or time() + + safemode = ( + dj.config.get("safemode", True) + if kwargs.get("safemode") is None + else kwargs["safemode"] + ) + _ = kwargs.pop("safemode", None) + + ran_deletes = True + if down_fts: + for down_ft in down_fts: + dj_logger.info( + f"Spyglass: Deleting {len(down_ft)} rows from " + + f"{down_ft.full_table_name}" + ) + if ( + self._test_mode + or not safemode + or user_choice("Commit deletes?", default="no") == "yes" + ): + for down_ft in down_fts: # safemode off b/c already checked + down_ft.delete(safemode=False, **kwargs) + else: + logger.info("Delete aborted.") + ran_deletes = False + + self._log_delete(start, del_blob=down_fts if ran_deletes else None) + + return ran_deletes + def delete_downstream_parts( self, restriction: str = None, @@ -334,6 +371,8 @@ def delete_downstream_parts( """ from spyglass.utils.dj_graph import RestrGraph # noqa F401 + start = time() + if reload_cache: _ = self.__dict__.pop("_part_masters", None) @@ -350,7 +389,8 @@ def delete_downstream_parts( if return_graph: return restr_graph - _, down_fts = restr_graph.find_orphans(self._part_masters) + # Depends on distance as a proxy for downstream-ness of each + down_fts = restr_graph.ft_from_list(self._part_masters, sort_from=self) if not down_fts and not disable_warning: logger.warning( @@ -362,8 +402,7 @@ def delete_downstream_parts( if dry_run: return down_fts - for master_ft in down_fts: - master_ft.delete(**kwargs) + self._commit_downstream_delete(down_fts, start, **kwargs) def ddp( self, *args, **kwargs @@ -383,7 +422,7 @@ def _delete_deps(self) -> List[Table]: lists. Used to delay import of tables until needed, avoiding circular imports. - Each of these tables inherits SpyglassMixin. + Each of these tables inheits SpyglassMixin. """ from spyglass.common import ( # noqa F401 IntervalList, @@ -460,7 +499,7 @@ def _check_delete_permission(self) -> None: Permission denied because (a) Session has no experimenter, or (b) user is not on a team with Session experimenter(s). """ - LabMember, LabTeam, Session, _ = self._delete_deps + LabMember, LabTeam, Session, _, _ = self._delete_deps dj_user = dj.config["database.user"] if dj_user in LabMember().admin: # bypass permission check for admin @@ -509,7 +548,7 @@ def _cautious_del_tbl(self): return CautiousDelete() - def _log_delete(self, start, master_deletes=None, super_delete=False): + def _log_delete(self, start, del_blob=None, super_delete=False): """Log use of cautious_delete.""" safe_insert = dict( duration=time() - start, @@ -523,7 +562,7 @@ def _log_delete(self, start, master_deletes=None, super_delete=False): dict( **safe_insert, restriction=restr_str[:255], - merge_deletes=master_deletes, + merge_deletes=del_blob, ) ) except (DataJointError, DataError): @@ -544,8 +583,8 @@ def cautious_delete( a team with the Session experimenter(s), a PermissionError is raised. Potential downstream orphans are deleted first. These are master tables - whose parts have foreign keys so descendants of self. Then, rows from - self are deleted. Last, IntervalList and Nwbfile externals are deleted. + whose parts have foreign keys to descendants of self. Then, rows from + self are deleted. Last, Nwbfile and IntervalList externals are deleted. Parameters ---------- @@ -559,68 +598,37 @@ def cautious_delete( Passed to datajoint.table.Table.delete. 
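# Example (sketch): `_commit_downstream_delete` above centralizes the safemode
# and "Commit deletes?" confirmation, so a scripted cleanup can skip the prompt
# per call rather than toggling the global config. Placeholder restriction.
from spyglass.common import Nwbfile

(Nwbfile() & "nwb_file_name LIKE 'example%'").delete_downstream_parts(
    dry_run=False,
    safemode=False,  # skip the confirmation prompt for this call only
)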
""" start = time() - external = self._delete_deps[3] + external, IntervalList = self._delete_deps[3], self._delete_deps[4] if not force_permission or dry_run: self._check_delete_permission() - restr_graph = self.delete_downstream_parts( + down_fts = self.delete_downstream_parts( dry_run=True, disable_warning=True, - return_graph=True, ) - up_fts, down_fts = restr_graph.find_orphans(self._part_masters) if dry_run: return ( - up_fts, down_fts, + IntervalList(), # cleanup func relies on downstream deletes external["raw"].unused(), external["analysis"].unused(), ) - safemode = ( - dj.config.get("safemode", True) - if kwargs.get("safemode") is None - else kwargs["safemode"] - ) - _ = kwargs.pop("safemode", None) - - if up_fts or down_fts: - for down_ft in down_fts: - dj_logger.info( - f"Spyglass: Deleting {len(down_ft)} rows from " - + f"{down_ft.full_table_name}" - ) - for up_ft in up_fts: - dj_logger.info( - f"Spyglass: Deleting {len(up_ft)} rows from " - + f"{up_ft.full_table_name} after next prompt" - ) - if ( - self._test_mode - or not safemode - or user_choice("Commit deletes?", default="no") == "yes" - ): - for down_ft in down_fts: # safemode off b/c already checked - down_ft.delete(safemode=False, **kwargs) - else: - logger.info("Delete aborted.") - self._log_delete(start) - return - - super().delete(*args, **kwargs) # Additional confirm here + if not self._commit_downstream_delete(down_fts, start=start, **kwargs): + return # Abort delete based on user input - if up_fts: - for up_ft in up_fts: - up_ft.delete(safemode=False, **kwargs) + super().delete(*args, **kwargs) # Confirmation here for ext_type in ["raw", "analysis"]: external[ext_type].delete( - delete_external=True, display_progress=True + delete_external_files=True, display_progress=False ) - self._log_delete(start=start, master_deletes=up_fts + down_fts) + _ = IntervalList().nightly_cleanup(dry_run=False) + + self._log_delete(start=start, del_blob=down_fts) def cdel(self, *args, **kwargs): """Alias for cautious_delete.""" @@ -850,6 +858,8 @@ def restrict_by( try: ret = self.restrict(restriction) # Save time trying first if len(ret) < len(self): + # If it actually restricts, if not it might by a dict that + # is not a valid restriction, returned as True logger.warning("Restriction valid for this table. 
Using as is.") return ret except DataJointError: @@ -875,6 +885,9 @@ def restrict_by( **kwargs, ) + if not graph.found_restr: + return None + if return_graph: return graph diff --git a/tests/common/test_usage.py b/tests/common/test_usage.py new file mode 100644 index 000000000..71449b3e3 --- /dev/null +++ b/tests/common/test_usage.py @@ -0,0 +1,89 @@ +import pytest + + +@pytest.fixture(scope="session") +def export_tbls(common): + from spyglass.common.common_usage import Export, ExportSelection + + return ExportSelection(), Export() + + +@pytest.fixture(scope="session") +def gen_export_selection( + lfp, trodes_pos_v1, track_graph, export_tbls, populate_lfp +): + ExportSelection, _ = export_tbls + _ = populate_lfp + + ExportSelection.start_export(paper_id=1, analysis_id=1) + lfp.v1.LFPV1().fetch_nwb() + trodes_pos_v1.fetch() + ExportSelection.start_export(paper_id=1, analysis_id=2) + track_graph.fetch() + ExportSelection.stop_export() + + yield dict(paper_id=1) + + ExportSelection.stop_export() + ExportSelection.super_delete(warn=False, safemode=False) + + +def test_export_selection_files(gen_export_selection, export_tbls): + ExportSelection, _ = export_tbls + paper_key = gen_export_selection + + len_fi = len(ExportSelection * ExportSelection.File & paper_key) + assert len_fi == 1, "Selection files not captured correctly" + + +def test_export_selection_tables(gen_export_selection, export_tbls): + ExportSelection, _ = export_tbls + paper_key = gen_export_selection + + paper = ExportSelection * ExportSelection.Table & paper_key + len_tbl_1 = len(paper & dict(analysis_id=1)) + len_tbl_2 = len(paper & dict(analysis_id=2)) + assert len_tbl_1 == 2, "Selection tables not captured correctly" + assert len_tbl_2 == 1, "Selection tables not captured correctly" + + +def tests_export_selection_max_id(gen_export_selection, export_tbls): + ExportSelection, _ = export_tbls + _ = gen_export_selection + + exp_id = max(ExportSelection.fetch("export_id")) + got_id = ExportSelection._max_export_id(1) + assert exp_id == got_id, "Max export id not captured correctly" + + +@pytest.fixture(scope="session") +def populate_export(export_tbls, gen_export_selection): + _, Export = export_tbls + Export.populate_paper(**gen_export_selection) + key = (Export & gen_export_selection).fetch("export_id", as_dict=True) + + yield (Export.Table & key), (Export.File & key) + + Export.super_delete(warn=False, safemode=False) + + +def test_export_populate(populate_export): + table, file = populate_export + + assert len(file) == 4, "Export tables not captured correctly" + assert len(table) == 31, "Export files not captured correctly" + + +def test_invalid_export_id(export_tbls): + ExportSelection, _ = export_tbls + ExportSelection.start_export(paper_id=2, analysis_id=1) + with pytest.raises(RuntimeError): + ExportSelection.export_id = 99 + ExportSelection.stop_export() + + +def test_del_export_id(export_tbls): + ExportSelection, _ = export_tbls + ExportSelection.start_export(paper_id=2, analysis_id=1) + del ExportSelection.export_id + assert ExportSelection.export_id == 0, "Export id not reset correctly" diff --git a/tests/conftest.py b/tests/conftest.py index cd9350ff1..94453968b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -301,11 +301,9 @@ def mini_insert( _ = SpikeSortingOutput() - LabMember().insert1( - ["Root User", "Root", "User"], skip_duplicates=not teardown - ) + LabMember().insert1(["Root User", "Root", "User"], skip_duplicates=True) LabMember.LabMemberInfo().insert1( - ["Root User", "email", "root", 1], 
skip_duplicates=not teardown + ["Root User", "email", "root", 1], skip_duplicates=True ) dj_logger.info("Inserting test data.") @@ -403,6 +401,32 @@ def populate_exception(): yield PopulateException +@pytest.fixture(scope="session") +def frequent_imports(): + """Often needed for graph cascade.""" + from spyglass.common.common_ripple import RippleLFPSelection + from spyglass.decoding.v0.clusterless import UnitMarksIndicatorSelection + from spyglass.decoding.v0.sorted_spikes import ( + SortedSpikesIndicatorSelection, + ) + from spyglass.decoding.v1.core import PositionGroup + from spyglass.lfp.analysis.v1 import LFPBandSelection + from spyglass.mua.v1.mua import MuaEventsV1 + from spyglass.ripple.v1.ripple import RippleTimesV1 + from spyglass.spikesorting.v0.figurl_views import SpikeSortingRecordingView + + return ( + LFPBandSelection, + MuaEventsV1, + PositionGroup, + RippleLFPSelection, + RippleTimesV1, + SortedSpikesIndicatorSelection, + SpikeSortingRecordingView, + UnitMarksIndicatorSelection, + ) + + # ------------------------- FIXTURES, POSITION TABLES ------------------------- diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py index a4bc7f900..726b6b8a7 100644 --- a/tests/utils/conftest.py +++ b/tests/utils/conftest.py @@ -30,23 +30,14 @@ def schema_test(teardown, dj_conn): @pytest.fixture(scope="module") -def chains(Nwbfile): - """Return example TableChains object from Nwbfile.""" - from spyglass.lfp.lfp_merge import LFPOutput # noqa: F401 +def chain(Nwbfile): + """Return example TableChain object from chains.""" from spyglass.linearization.merge import ( LinearizedPositionOutput, ) # noqa: F401 - from spyglass.position.position_merge import PositionOutput # noqa: F401 - - _ = LFPOutput, LinearizedPositionOutput, PositionOutput - - yield Nwbfile._get_chain("linear") - + from spyglass.utils.dj_graph import TableChain -@pytest.fixture(scope="module") -def chain(chains): - """Return example TableChain object from chains.""" - yield chains[0] + yield TableChain(Nwbfile, LinearizedPositionOutput) @pytest.fixture(scope="module") diff --git a/tests/utils/test_chains.py b/tests/utils/test_chains.py index 66d9772c3..093ed5485 100644 --- a/tests/utils/test_chains.py +++ b/tests/utils/test_chains.py @@ -13,27 +13,6 @@ def full_to_camel(t): return to_camel_case(t.split(".")[-1].strip("`")) -def test_chains_repr(chains): - """Test that the repr of a TableChains object is as expected.""" - repr_got = repr(chains) - chain_st = ",\n\t".join([str(c) for c in chains.chains]) + "\n" - repr_exp = f"TableChains(\n\t{chain_st})" - assert repr_got == repr_exp, "Unexpected repr of TableChains object." - - -def test_str_getitem(chains): - """Test getitem of TableChains object.""" - by_int = chains[0] - by_str = chains[chains.part_names[0]] - assert by_int == by_str, "Getitem by int and str not equal." - - -def test_invalid_chain(Nwbfile, pos_merge_tables, TableChain): - """Test that an invalid chain raises an error.""" - with pytest.raises(TypeError): - TableChain(Nwbfile, pos_merge_tables[0]) - - def test_chain_str(chain): """Test that the str of a TableChain object is as expected.""" chain = chain @@ -64,8 +43,8 @@ def test_chain_len(chain): def test_chain_getitem(chain): """Test getitem of TableChain object.""" - by_int = chain[0] - by_str = chain[chain.path[0]] + by_int = str(chain[0]) + by_str = str(chain[chain.restr_ft[0].full_table_name]) assert by_int == by_str, "Getitem by int and str not equal." 
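# Example (sketch): the `chain` fixture above, built outside pytest. Assumes
# both endpoint tables are imported and actually linked in this pipeline.
from spyglass.common import Nwbfile
from spyglass.linearization.merge import LinearizedPositionOutput
from spyglass.utils.dj_graph import TableChain

chain = TableChain(Nwbfile, LinearizedPositionOutput)  # parent, child
if chain.has_link:
    print(chain.path_str)  # CamelCase table names joined along the found path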
@@ -76,3 +55,9 @@ def test_nolink_join(no_link_chain): def test_chain_str_no_link(no_link_chain): """Test that the str of a TableChain object with no link is as expected.""" assert str(no_link_chain) == "No link", "Unexpected str of no link chain." + assert repr(no_link_chain) == "No link", "Unexpected repr of no link chain." + + +def test_invalid_chain(TableChain): + with pytest.raises(ValueError): + TableChain() diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 7d5257a36..ab348ad2b 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -1,4 +1,7 @@ import pytest +from datajoint.utils import to_camel_case + +from tests.conftest import VERBOSE @pytest.fixture(scope="session") @@ -14,8 +17,7 @@ def restr_graph(leaf, verbose, lin_merge_key): yield RestrGraph( seed_table=leaf, - table_name=leaf.full_table_name, - restriction=True, + leaves={leaf.full_table_name: True}, cascade=True, verbose=verbose, ) @@ -26,13 +28,19 @@ def test_rg_repr(restr_graph, leaf): repr_got = repr(restr_graph) assert "cascade" in repr_got.lower(), "Cascade not in repr." - assert leaf.full_table_name in repr_got, "Table name not in repr." + + assert to_camel_case(leaf.table_name) in repr_got, "Table name not in repr." + + +def test_rg_len(restr_graph): + assert len(restr_graph) == len( + restr_graph.restr_ft + ), "Unexpected length of RestrGraph." def test_rg_ft(restr_graph): """Test FreeTable attribute of RestrGraph.""" assert len(restr_graph.leaf_ft) == 1, "Unexpected # of leaf tables." - assert len(restr_graph["spatial"]) == 2, "Unexpected cascaded table length." def test_rg_restr_ft(restr_graph): @@ -43,8 +51,41 @@ def test_rg_restr_ft(restr_graph): def test_rg_file_paths(restr_graph): """Test collection of upstream file paths.""" - paths = [p.get("file_path") for p in restr_graph.file_paths] - assert len(paths) == 2, "Unexpected number of file paths." + assert len(restr_graph.file_paths) == 2, "Unexpected number of file paths." + + +def test_rg_invalid_table(restr_graph): + """Test that an invalid table raises an error.""" + with pytest.raises(ValueError): + restr_graph._get_node("invalid_table") + + +def test_rg_invalid_edge(restr_graph, Nwbfile, common): + """Test that an invalid edge raises an error.""" + with pytest.raises(ValueError): + restr_graph._get_edge(Nwbfile, common.common_behav.PositionSource) + + +def test_rg_restr_subset(restr_graph, leaf): + prev_ft = restr_graph._get_ft(leaf.full_table_name, with_restr=True) + + restr_graph._set_restr(leaf, restriction=False) + + new_ft = restr_graph._get_ft(leaf.full_table_name, with_restr=True) + assert len(prev_ft) == len(new_ft), "Subset sestriction changed length." + + +@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy") +def test_rg_no_restr(caplog, restr_graph, common): + restr_graph._set_restr(common.LabTeam, restriction=False) + restr_graph._get_ft(common.LabTeam.full_table_name, with_restr=True) + assert "No restr" in caplog.text, "No warning logged on no restriction." 
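# Example (sketch): RestrGraph indexing is fuzzy -- an integer position or a
# (partial) table name both resolve to the same restricted FreeTable. Assumes
# `restr_graph` is a cascaded RestrGraph like the fixture above.
first = restr_graph[0]
same = restr_graph[first.full_table_name]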
+ + +def test_rg_invalid_direction(restr_graph, leaf): + """Test that an invalid direction raises an error.""" + with pytest.raises(ValueError): + restr_graph._get_next_tables(leaf.full_table_name, "invalid_direction") @pytest.fixture(scope="session") @@ -72,13 +113,14 @@ def test_add_leaf_restr_ft(restr_graph_new_leaf): @pytest.fixture(scope="session") -def restr_graph_root(restr_graph, common, lfp_band, lin_v1): +def restr_graph_root(restr_graph, frequent_imports, common, lfp_band, lin_v1): from spyglass.utils.dj_graph import RestrGraph + _ = frequent_imports # part of cascade, need import + yield RestrGraph( seed_table=common.Session(), - table_name=common.Session.full_table_name, - restriction="True", + leaves={common.Session.full_table_name: "True"}, direction="down", cascade=True, verbose=False, @@ -87,7 +129,7 @@ def restr_graph_root(restr_graph, common, lfp_band, lin_v1): def test_rg_root(restr_graph_root): assert ( - len(restr_graph_root["trodes_pos_v1"]) == 2 + len(restr_graph_root["trodes_pos_v1"]) >= 1 ), "Incomplete cascade from root." @@ -123,6 +165,52 @@ def test_restr_from_downstream(graph_tables, table, restr, expect_n, msg): assert len(graph_tables[table]() << restr) == expect_n, msg +@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") +def test_ban_node(caplog, graph_tables): + search_restr = "sk_attr > 17" + ParentNode = graph_tables["ParentNode"]() + SkNode = graph_tables["SkNode"]() + + ParentNode.ban_search_table(SkNode) + ParentNode >> search_restr + assert "could not be applied" in caplog.text, "Found banned table." + + ParentNode.see_banned_tables() + assert "Banned tables" in caplog.text, "Banned tables not logged." + + ParentNode.unban_search_table(SkNode) + assert len(ParentNode >> search_restr) == 3, "Unban failed." + + +def test_null_restrict_by(graph_tables): + PkNode = graph_tables["PkNode"]() + assert (PkNode >> True) == PkNode, "Null restriction failed." + + +@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") +def test_restrict_by_this_table(caplog, graph_tables): + PkNode = graph_tables["PkNode"]() + PkNode >> "pk_id > 4" + assert "valid for" in caplog.text, "No warning logged without search." + + +def test_invalid_restr_direction(graph_tables): + PkNode = graph_tables["PkNode"]() + with pytest.raises(ValueError): + PkNode.restrict_by("bad_attr > 0", direction="invalid_direction") + + +@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") +def test_warn_nonrestrict(caplog, graph_tables): + ParentNode = graph_tables["ParentNode"]() + restr_parent = ParentNode & "parent_id > 4 AND parent_id < 9" + + restr_parent >> "sk_id > 0" + assert "Same length" in caplog.text, "No warning logged on non-restrict." + restr_parent >> "sk_id > 99" + assert "No entries" in caplog.text, "No warning logged on non-restrict." + + def test_restr_many_to_one(graph_tables_many_to_one): PK = graph_tables_many_to_one["PkSkNode"]() OP = graph_tables_many_to_one["OtherParentNode"]() @@ -137,7 +225,35 @@ def test_restr_many_to_one(graph_tables_many_to_one): ), "Error accepting list of dicts for `>>` for many to one." 
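# Example (sketch): the long-distance restriction controls exercised above.
# ParentNode / SkNode and the attribute names come from the test schema and
# are illustrative outside the test suite.
parent = ParentNode()
parent >> "sk_attr > 17"      # restrict via a table found downstream
parent << "parent_attr < 20"  # ...or via one found upstream

parent.ban_search_table(SkNode)  # exclude a table from the search
parent.see_banned_tables()       # log the current ban list
parent.unban_search_table(SkNode)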
-def test_restr_invalid(graph_tables): +def test_restr_invalid_err(graph_tables): PkNode = graph_tables["PkNode"]() with pytest.raises(ValueError): len(PkNode << set(["parent_attr > 15", "parent_attr < 20"])) + + +@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") +def test_restr_invalid(caplog, graph_tables): + graph_tables["PkNode"]() << "invalid_restr=1" + assert ( + "could not be applied" in caplog.text + ), "No warning logged on invalid restr." + + +@pytest.fixture(scope="session") +def direction(): + from spyglass.utils.dj_graph import Direction + + yield Direction + + +def test_direction_str(direction): + assert str(direction.UP) == "up", "Direction str not as expected." + + +def test_direction_invert(direction): + assert ~direction.UP == direction("down"), "Direction inversion failed." + + +def test_direction_bool(direction): + assert bool(direction.UP), "Direction bool not as expected." + assert not direction.NONE, "Direction bool not as expected." diff --git a/tests/utils/test_merge.py b/tests/utils/test_merge.py index 9c192c20a..2876555a1 100644 --- a/tests/utils/test_merge.py +++ b/tests/utils/test_merge.py @@ -35,6 +35,25 @@ def test_nwb_table_missing(BadMerge, caplog, schema_test): assert "non-default definition" in txt, "Warning not caught." +@pytest.fixture(scope="function") +def NonMerge(): + from spyglass.utils import SpyglassMixin + + class NonMerge(SpyglassMixin, dj.Manual): + definition = """ + merge_id : uuid + --- + source : varchar(32) + """ + + yield NonMerge + + +def test_non_merge(NonMerge): + with pytest.raises(AttributeError): + NonMerge() + + def test_part_camel(merge_table): example_part = merge_table.parts(camel_case=True)[0] assert "_" not in example_part, "Camel case not applied." diff --git a/tests/utils/test_mixin.py b/tests/utils/test_mixin.py index c70c67b13..7901c1936 100644 --- a/tests/utils/test_mixin.py +++ b/tests/utils/test_mixin.py @@ -8,10 +8,11 @@ def Mixin(): from spyglass.utils import SpyglassMixin - class Mixin(SpyglassMixin, dj.Manual): + class Mixin(SpyglassMixin, dj.Lookup): definition = """ id : int """ + contents = [(0,), (1,)] yield Mixin @@ -32,37 +33,52 @@ def test_nwb_table_missing(schema_test, Mixin): Mixin().fetch_nwb() -def test_merge_detect(Nwbfile, pos_merge_tables): +def test_auto_increment(schema_test, Mixin): + schema_test(Mixin) + ret = Mixin()._auto_increment(key={}, pk="id") + assert ret["id"] == 2, "Auto increment not working." + + +def test_null_file_like(schema_test, Mixin): + schema_test(Mixin) + ret = Mixin().file_like(None) + assert len(ret) == len(Mixin()), "Null file_like not working." + + +@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") +def test_bad_file_like(caplog, schema_test, Mixin): + schema_test(Mixin) + Mixin().file_like("BadName") + assert "No file_like field" in caplog.text, "No warning issued." + + +def test_partmaster_detect(Nwbfile, pos_merge_tables): """Test that the mixin can detect merge children of merge.""" - merges_found = set(Nwbfile._merge_chains.keys()) - merges_expected = set([t.full_table_name for t in pos_merge_tables]) - assert merges_expected.issubset( - merges_found - ), "Merges not detected by mixin." + assert len(Nwbfile._part_masters) >= 14, "Part masters not detected." 
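# Example (sketch): the Direction conveniences pinned down by the tests added
# above (in test_graph), which the graph code leans on when cascading.
from spyglass.utils.dj_graph import Direction

assert str(Direction.UP) == "up"           # str() yields the raw value
assert ~Direction.UP == Direction("down")  # inversion flips the direction
assert not Direction.NONE                  # falsy when no direction is set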
-def test_merge_chain_join(Nwbfile, pos_merge_tables, lin_v1, lfp_merge_key): +def test_downstream_restrict( + Nwbfile, frequent_imports, pos_merge_tables, lin_v1, lfp_merge_key +): """Test that the mixin can join merge chains.""" + + _ = frequent_imports # graph for cascade _ = lin_v1, lfp_merge_key # merge tables populated - all_chains = [ - chains.cascade(True, direction="down") - for chains in Nwbfile._merge_chains.values() - ] - end_len = [len(chain[0]) for chain in all_chains if chain] + restr_ddp = Nwbfile.ddp(dry_run=True, reload_cache=True) + end_len = [len(ft) for ft in restr_ddp] - assert sum(end_len) == 4, "Merge chains not joined correctly." + assert sum(end_len) >= 8, "Downstream parts not restricted correctly." -def test_get_chain(Nwbfile, pos_merge_tables): +def test_get_downstream_merge(Nwbfile, pos_merge_tables): """Test that the mixin can get the chain of a merge.""" - lin_parts = Nwbfile._get_chain("linear").part_names - lin_output = pos_merge_tables[1] - assert lin_parts == lin_output.parts(), "Chain not found." + lin_output = pos_merge_tables[1].full_table_name + assert lin_output in Nwbfile._part_masters, "Merge not found." @pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") -def test_ddm_warning(Nwbfile, caplog): +def test_ddp_warning(Nwbfile, caplog): """Test that the mixin warns on empty delete_downstream_merge.""" (Nwbfile.file_like("BadName")).delete_downstream_parts( reload_cache=True, disable_warnings=False @@ -70,22 +86,39 @@ def test_ddm_warning(Nwbfile, caplog): assert "No part deletes found" in caplog.text, "No warning issued." -def test_ddm_dry_run(Nwbfile, common, sgp, pos_merge_tables, lin_v1): +def test_ddp_dry_run( + Nwbfile, frequent_imports, common, sgp, pos_merge_tables, lin_v1 +): """Test that the mixin can dry run delete_downstream_merge.""" _ = lin_v1 # merge tables populated + _ = frequent_imports # graph for cascade + pos_output_name = pos_merge_tables[0].full_table_name param_field = "trodes_pos_params_name" trodes_params = sgp.v1.TrodesPosParams() - rft = (trodes_params & f'{param_field} LIKE "%ups%"').ddm( - reload_cache=True, dry_run=True, return_parts=False - )[pos_output_name][0] - assert len(rft) == 1, "ddm did not return restricted table." + rft = [ + table + for table in (trodes_params & f'{param_field} LIKE "%ups%"').ddp( + reload_cache=True, dry_run=True + ) + if table.full_table_name == pos_output_name + ] + assert len(rft) == 1, "ddp did not return restricted table." + + +def test_exp_summary(Nwbfile): + fields = Nwbfile._get_exp_summary().heading.names + expected = ["nwb_file_name", "lab_member_name"] + assert fields == expected, "Exp summary fields not as expected." - table_name = [p for p in pos_merge_tables[0].parts() if "trode" in p][0] - assert table_name == rft.full_table_name, "ddm didn't grab right table." - assert ( - rft.fetch1(param_field) == "single_led_upsampled" - ), "ddm didn't grab right row." +def test_cautious_del_dry_run(Nwbfile, frequent_imports): + _ = frequent_imports # part of cascade, need import + ret = Nwbfile.cdel(dry_run=True) + part_master_names = [t.full_table_name for t in ret[0]] + part_masters = Nwbfile._part_masters + assert all( + [pm in part_masters for pm in part_master_names] + ), "Non part masters found in cautious delete dry run." 
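The new `test_cautious_del_dry_run` above exercises the preview path of `cautious_delete`. Outside the test suite, that dry run can be unpacked directly; a sketch assuming an imported pipeline, with the returned items ordered as in this patch:

```python
from spyglass.common import Nwbfile

down_fts, interval_list, unused_raw, unused_analysis = Nwbfile().cdel(dry_run=True)

for ft in down_fts:  # downstream part masters that would be deleted first
    print(ft.full_table_name, len(ft))
```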
From 226d4915b09c3f11e5783d447a9f0c3c1f4b4db6 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 10 Jun 2024 11:20:24 -0500 Subject: [PATCH 05/14] Update changelog --- CHANGELOG.md | 2 ++ src/spyglass/common/common_interval.py | 2 +- src/spyglass/common/common_usage.py | 5 ++-- src/spyglass/utils/dj_graph.py | 41 ++++++++++++-------------- src/spyglass/utils/dj_mixin.py | 1 - 5 files changed, 24 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 54d6f087b..9222997c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ - Add long-distance restrictions via `<<` and `>>` operators. #943, #969 - Fix relative pathing for `mkdocstring-python=>1.9.1`. #967, #968 - Clean up old `TableChain.join` call in mixin delete. #982 +- Expand `delete_downstream_merge` -> `delete_downstream_parts`. #1002 +- `cautious_delete` now checks `IntervalList` and externals tables. #1002 ### Pipelines diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index 56b3a15b5..6e4d6b042 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -157,7 +157,7 @@ def nightly_cleanup(self, dry_run=True): orphans = self - get_child_tables(self) if dry_run: return orphans - orphans.super_delete() + orphans.super_delete(warn=False) def intervals_by_length(interval_list, min_length=0.0, max_length=1e10): diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 5dca00185..b9aa2f528 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -98,9 +98,8 @@ def insert1_return_pk(self, key: dict, **kwargs) -> int: export_id = query.fetch1("export_id") export_key = {"export_id": export_id} if query := (Export & export_key): - if test_mode: - query.super_delete(warn=False, safemode=False) - query.super_delete(warn=False) + safemode = False if test_mode else None # No prompt in tests + query.super_delete(warn=False, safemode=safemode) logger.info(f"{status} {export_key}") return export_id diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 74f9e98b1..8f561c061 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -261,15 +261,6 @@ def _get_ft(self, table, with_restr=False, warn=True): return ft & restr - # ------------------------------ Ignore Nodes ------------------------------ - - def _ignore_peripheral(self, except_tables: List[str] = None): - """Ignore peripheral tables in graph traversal.""" - except_tables = self._ensure_names(except_tables) - ignore_tables = set(PERIPHERAL_TABLES) - set(except_tables or []) - self.no_visit.update(ignore_tables) - self.undirect_graph.remove_nodes_from(ignore_tables) - # ---------------------------- Graph Traversal ----------------------------- def _bridge_restr( @@ -573,9 +564,6 @@ def __init__( Default False verbose : bool, optional Whether to print verbose output. Default False - ignore_peripheral : bool, optional - Whether to ignore peripheral tables in graph traversal. 
Default - False """ super().__init__(seed_table, verbose=verbose) @@ -869,6 +857,15 @@ def __init__( self.cascade(restriction=search_restr) self.cascaded = True + # ------------------------------ Ignore Nodes ------------------------------ + + def _ignore_peripheral(self, except_tables: List[str] = None): + """Ignore peripheral tables in graph traversal.""" + except_tables = self._ensure_names(except_tables) + ignore_tables = set(PERIPHERAL_TABLES) - set(except_tables or []) + self.no_visit.update(ignore_tables) + self.undirect_graph.remove_nodes_from(ignore_tables) + # --------------------------- Dunder Properties --------------------------- def __str__(self): @@ -970,20 +967,18 @@ def cascade_search(self) -> None: + f"Restr: {restriction}" ) - def _and_parts(self, table): - """Return table, its master and parts.""" - ret = [table] - if master := get_master(table): - ret.append(master) - if parts := self._get_ft(table).parts(): - ret.extend(parts) - return ret - def _set_found_vars(self, table): """Set found_restr and searched_tables.""" self._set_restr(table, self.search_restr, replace=True) self.found_restr = True - self.searched_tables.update(set(self._and_parts(table))) + + and_parts = set([table]) + if master := get_master(table): + and_parts.add(master) + if parts := self._get_ft(table).parts(): + and_parts.update(parts) + + self.searched_tables.update(and_parts) if self.direction == Direction.UP: self.parent = table @@ -1121,6 +1116,8 @@ def cascade( ) # Cascade will stop if any restriction is empty, so set rest to None + # This would cause issues if we want a table partway through the chain + # but that's not a typical use case, were the start and end are desired non_numeric = [t for t in self.path if not t.isnumeric()] if any(self._get_restr(t) is None for t in non_numeric): for table in non_numeric: diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index bad393065..d5dfaa5cc 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -879,7 +879,6 @@ def restrict_by( direction=direction, search_restr=restriction, banned_tables=list(self._banned_search_tables), - allow_merge=True, cascade=True, verbose=verbose, **kwargs, From 2b66b28f6f5318d3fbbbd1451dbea76a245f46ef Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 10 Jun 2024 13:56:44 -0500 Subject: [PATCH 06/14] Fix typo --- docs/src/misc/mixin.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/misc/mixin.md b/docs/src/misc/mixin.md index 5f53845d6..23135d3c4 100644 --- a/docs/src/misc/mixin.md +++ b/docs/src/misc/mixin.md @@ -141,7 +141,7 @@ order. The mixin provides a function, `delete_downstream_parts`, to handle this, which is run by default when calling `delete`. `delete_downstream_parts`, also aliased as `ddp`, identifies all part tables -with foreign key references downsteam of where it is called. If `dry_run=True`, +with foreign key references downstream of where it is called. If `dry_run=True`, it will return a list of entries that would be deleted, otherwise it will delete them. 
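A dry run is a cheap way to sanity-check the scope of a part delete before committing to it; a sketch with a placeholder file-name pattern:

```python
from spyglass.common import Nwbfile

for ft in (Nwbfile() & "nwb_file_name LIKE 'example%'").ddp(dry_run=True):
    print(f"{ft.full_table_name}: {len(ft)} rows would be deleted")
```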
From 24caf64607f005bda30128b07a53ebb1b88500d0 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 10 Jun 2024 17:07:17 -0500 Subject: [PATCH 07/14] WIP: topological sort of deletes --- src/spyglass/decoding/v0/core.py | 1 + src/spyglass/utils/dj_graph.py | 62 +++++++++++++++--------------- src/spyglass/utils/dj_helper_fn.py | 9 +++-- src/spyglass/utils/dj_mixin.py | 5 ++- 4 files changed, 42 insertions(+), 35 deletions(-) diff --git a/src/spyglass/decoding/v0/core.py b/src/spyglass/decoding/v0/core.py index 5664c12d9..3df82f318 100644 --- a/src/spyglass/decoding/v0/core.py +++ b/src/spyglass/decoding/v0/core.py @@ -13,6 +13,7 @@ ObservationModel, ) except ImportError as e: + RandomWalk, Uniform, Environment, ObservationModel = None, None, None, None logger.warning(e) from spyglass.common.common_behav import PositionIntervalMap, RawPosition diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 8f561c061..3e010e22b 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -30,7 +30,6 @@ fuzzy_get, unique_dicts, ) -from spyglass.utils.dj_merge_tables import is_merge_table class Direction(Enum): @@ -448,6 +447,30 @@ def cascade1( # ---------------------------- Graph Properties ---------------------------- + def _topo_sort( + self, nodes: List[str], subgraph: bool = True, reverse: bool = False + ) -> List[str]: + """Return topologically sorted list of nodes. + + Parameters + ---------- + nodes : List[str] + List of table names + subgraph : bool, optional + Whether to use subgraph. Default True + reverse : bool, optional + Whether to reverse the order. Default False. If true, bottom-up. + If None, return nodes as is. + """ + nodes = self._ensure_names(nodes) + if reverse is None: + return nodes + graph = self.graph.subgraph(nodes) if subgraph else self.graph + ordered = unite_master_parts(list(topological_sort(graph))) + if reverse: + ordered.reverse() + return [n for n in ordered if n in nodes] + @property def all_ft(self): """Get restricted FreeTables from all visited nodes. @@ -456,14 +479,10 @@ def all_ft(self): """ self.cascade(warn=False) nodes = [n for n in self.visited if not n.isnumeric()] - sorted_nodes = unite_master_parts( - list(topological_sort(self.graph.subgraph(nodes))) - ) - ret = [ + return [ self._get_ft(table, with_restr=True, warn=False) - for table in sorted_nodes + for table in self._topo_sort(nodes, subgraph=True, reverse=False) ] - return ret @property def restr_ft(self): @@ -474,7 +493,7 @@ def ft_from_list( self, tables: List[str], with_restr: bool = True, - sort_from: str = None, + sort_reverse: bool = None, return_empty: bool = False, ) -> List[FreeTable]: """Return non-empty FreeTable objects from list of table names. @@ -485,34 +504,17 @@ def ft_from_list( List of table names with_restr : bool, optional Restrict FreeTable to restriction. Default True. - sort_from : str, optional - Table name. Sort by decreasing distance from this table. - Default None, no sort. + sort_reverse : bool, optional + Sort reverse topologically. Default True. If None, no sort. """ - def graph_distance(self, table1: str = None, table2: str = None) -> int: - """Sort tables by distance from root. 
If no root, do nothing.""" - if not table1 or not table2: - return 0 - try: - return len(shortest_path(self.undirect_graph, table1, table2)) - except (NodeNotFound, NetworkXNoPath): - return 99 - self.cascade(warn=False) - tables = [self._ensure_names(t) for t in tables] - - if sort_from: - sort_from = self._ensure_names(sort_from) - tables = sorted( - tables, - key=lambda t: graph_distance(sort_from, t), - reverse=True, # sort from farthest to closest - ) fts = [ self._get_ft(table, with_restr=with_restr, warn=False) - for table in tables + for table in self._topo_sort( + tables, subgraph=False, reverse=sort_reverse + ) ] return fts if return_empty else [ft for ft in fts if len(ft) > 0] diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 3fa18191c..290d84e54 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -261,7 +261,10 @@ def fetch_nwb(query_expression, nwb_master, *attrs, **kwargs): # skip the filepath checksum if streamed from Dandi rec_dict["nwb2load_filepath"] = file_path continue - rec_dict["nwb2load_filepath"] = (query_table & rec_dict).fetch1( + + # Full dect caused issues with dlc tables using dicts in secondary keys + rec_only_pk = {k: rec_dict[k] for k in query_table.heading.primary_key} + rec_dict["nwb2load_filepath"] = (query_table & rec_only_pk).fetch1( "nwb2load_filepath" ) @@ -331,7 +334,7 @@ def update_analysis_for_dandi_standard( # edit the file with h5py.File(filepath, "a") as file: sex_value = file["/general/subject/sex"][()].decode("utf-8") - if not sex_value in ["Female", "Male", "F", "M", "O", "U"]: + if sex_value not in ["Female", "Male", "F", "M", "O", "U"]: raise ValueError(f"Unexpected value for sex: {sex_value}") if len(sex_value) > 1: @@ -354,7 +357,7 @@ def update_analysis_for_dandi_standard( len(species_value.split(" ")) == 2 or "NCBITaxon" in species_value ): raise ValueError( - f"Dandi upload requires species either be in Latin binomial form (e.g., 'Mus musculus' and 'Homo sapiens')" + "Dandi upload requires species either be in Latin binomial form (e.g., 'Mus musculus' and 'Homo sapiens')" + "or be a NCBI taxonomy link (e.g., 'http://purl.obolibrary.org/obo/NCBITaxon_280675')." 
+ f"\n Please update species value of: {species_value}" ) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index d5dfaa5cc..79fd8447b 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -389,8 +389,9 @@ def delete_downstream_parts( if return_graph: return restr_graph - # Depends on distance as a proxy for downstream-ness of each - down_fts = restr_graph.ft_from_list(self._part_masters, sort_from=self) + down_fts = restr_graph.ft_from_list( + self._part_masters, sort_reverse=True + ) if not down_fts and not disable_warning: logger.warning( From cd3b268501e16efbeac873890615c95f67786178 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 10 Jun 2024 17:57:36 -0500 Subject: [PATCH 08/14] =?UTF-8?q?=20=E2=9C=85=20:=20Topological=20sort?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/spyglass/utils/dj_helper_fn.py | 16 ++++++++++------ src/spyglass/utils/dj_mixin.py | 3 +-- tests/conftest.py | 20 +++++++++++++++++++- 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 290d84e54..0b9155cc7 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -348,8 +348,9 @@ def update_analysis_for_dandi_standard( species_value = file["/general/subject/species"][()].decode("utf-8") if species_value == "Rat": new_species_value = "Rattus norvegicus" - print( - f"Adjusting subject species from '{species_value}' to '{new_species_value}'." + logger.info( + f"Adjusting subject species from '{species_value}' to " + + f"'{new_species_value}'." ) file["/general/subject/species"][()] = new_species_value @@ -357,9 +358,11 @@ def update_analysis_for_dandi_standard( len(species_value.split(" ")) == 2 or "NCBITaxon" in species_value ): raise ValueError( - "Dandi upload requires species either be in Latin binomial form (e.g., 'Mus musculus' and 'Homo sapiens')" - + "or be a NCBI taxonomy link (e.g., 'http://purl.obolibrary.org/obo/NCBITaxon_280675')." - + f"\n Please update species value of: {species_value}" + "Dandi upload requires species either be in Latin binomial form" + + " (e.g., 'Mus musculus' and 'Homo sapiens') or be a NCBI " + + "taxonomy link (e.g., " + + "'http://purl.obolibrary.org/obo/NCBITaxon_280675').\n " + + f"Please update species value of: {species_value}" ) # add subject age dataset "P4M/P8M" @@ -378,7 +381,8 @@ def update_analysis_for_dandi_standard( if experimenter_value != new_experimenter_value: new_experimenter_value = new_experimenter_value.astype(STR_DTYPE) logger.info( - f"Adjusting experimenter from {experimenter_value} to {new_experimenter_value}." + f"Adjusting experimenter from {experimenter_value} to " + + f"{new_experimenter_value}." 
) file["/general/experimenter"][:] = new_experimenter_value diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 79fd8447b..ba4ab6652 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -19,7 +19,6 @@ from spyglass.utils.database_settings import SHARED_MODULES from spyglass.utils.dj_helper_fn import fetch_nwb, get_nwb_table -from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK from spyglass.utils.dj_merge_tables import Merge, is_merge_table from spyglass.utils.logging import logger @@ -390,7 +389,7 @@ def delete_downstream_parts( return restr_graph down_fts = restr_graph.ft_from_list( - self._part_masters, sort_reverse=True + self._part_masters, sort_reverse=False ) if not down_fts and not disable_warning: diff --git a/tests/conftest.py b/tests/conftest.py index 3a52f3e2d..68487a025 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -238,6 +238,8 @@ def nodlc(request): def skipif_nodlc(request): if NO_DLC: yield pytest.mark.skip(reason="Skipping DLC-dependent tests.") + else: + yield @pytest.fixture(scope="session") @@ -842,6 +844,12 @@ def insert_project( from deeplabcut.utils.auxiliaryfunctions import read_config, write_config + from spyglass.decoding.v1.core import PositionGroup + from spyglass.linearization.merge import LinearizedPositionOutput + from spyglass.linearization.v1 import LinearizationSelection + from spyglass.mua.v1.mua import MuaEventsV1 + from spyglass.ripple.v1 import RippleTimesV1 + team_name = "sc_eb" common.LabTeam.insert1({"team_name": team_name}, skip_duplicates=True) with verbose_context: @@ -879,6 +887,13 @@ def insert_project( yield project_key, cfg, config_path if teardown: + + """ + DataJointError: Attempt to delete part table + `position_merge`.`position_output__d_l_c_pos_v1` before deleting from + its master `position_merge`.`position_output` first. 
+ """ + (dlc_project_tbl & project_key).delete(safemode=False) shutil_rmtree(str(Path(config_path).parent)) @@ -1218,7 +1233,10 @@ def populate_orient(sgp, orient_selection): @pytest.fixture(scope="session") -def dlc_selection(sgp, centroid_key, orient_key, populate_orient): +def dlc_selection( + sgp, centroid_key, orient_key, populate_orient, populate_centroid +): + _ = populate_orient, populate_centroid dlc_key = { key: val for key, val in centroid_key.items() From e50cb69193f82e132a9b199e7696d4fb8a5cfaf2 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 11 Jun 2024 12:41:57 -0500 Subject: [PATCH 09/14] Revise downloads --- .gitignore | 1 + pyproject.toml | 4 +-- tests/conftest.py | 62 ++++++++++++---------------------------- tests/data_downloader.py | 61 +++++++++++++++++++++++++-------------- 4 files changed, 62 insertions(+), 66 deletions(-) diff --git a/.gitignore b/.gitignore index f5d7d262f..5b4cb3ad8 100644 --- a/.gitignore +++ b/.gitignore @@ -60,6 +60,7 @@ coverage.xml .hypothesis/ .pytest_cache/ tests/_data/* +wget-log* # Translations *.mo diff --git a/pyproject.toml b/pyproject.toml index 78d189b73..d293b18e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,8 +125,8 @@ ignore-words-list = 'nevers' minversion = "7.0" addopts = [ # "-sv", # no capture, verbose output - # "--sw", # stepwise: resume with next test after failure - # "--pdb", # drop into debugger on failure + "--sw", # stepwise: resume with next test after failure + "--pdb", # drop into debugger on failure "-p no:warnings", # "--no-teardown", # don't teardown the database after tests # "--quiet-spy", # don't show logging from spyglass diff --git a/tests/conftest.py b/tests/conftest.py index 68487a025..39b7c2a84 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,6 @@ from contextlib import nullcontext from pathlib import Path from shutil import rmtree as shutil_rmtree -from time import sleep as tsleep import datajoint as dj import numpy as np @@ -210,20 +209,9 @@ def raw_dir(base_dir): @pytest.fixture(scope="session") def mini_path(raw_dir): path = raw_dir / TEST_FILE + DOWNLOADS.wait_for(TEST_FILE) # wait for wget download to finish - # wait for wget download to finish - if (nwb_download := DOWNLOADS.file_downloads.get(TEST_FILE)) is not None: - nwb_download.wait() - - # wait for download to finish - timeout, wait, found = 60, 5, False - for _ in range(timeout // wait): - if path.exists(): - found = True - break - tsleep(wait) - - if not found: + if not path.exists(): raise ConnectionError("Download failed.") yield path @@ -426,9 +414,9 @@ def frequent_imports(): @pytest.fixture(scope="session") def video_keys(common, base_dir): - for file, download in DOWNLOADS.file_downloads.items(): - if file.endswith(".h264") and download is not None: - download.wait() # wait for videos to finish downloading + for file in DOWNLOADS.file_downloads: + if file.endswith(".h264"): + DOWNLOADS.wait_for(file) DOWNLOADS.rename_files() return common.VideoFile().fetch(as_dict=True) @@ -850,6 +838,14 @@ def insert_project( from spyglass.mua.v1.mua import MuaEventsV1 from spyglass.ripple.v1 import RippleTimesV1 + _ = ( + PositionGroup, + LinearizedPositionOutput, + LinearizationSelection, + MuaEventsV1, + RippleTimesV1, + ) + team_name = "sc_eb" common.LabTeam.insert1({"team_name": team_name}, skip_duplicates=True) with verbose_context: @@ -887,13 +883,6 @@ def insert_project( yield project_key, cfg, config_path if teardown: - - """ - DataJointError: Attempt to delete part table - 
`position_merge`.`position_output__d_l_c_pos_v1` before deleting from - its master `position_merge`.`position_output` first. - """ - (dlc_project_tbl & project_key).delete(safemode=False) shutil_rmtree(str(Path(config_path).parent)) @@ -943,23 +932,8 @@ def labeled_vid_dir(extract_frames): @pytest.fixture(scope="session") -def fix_downloaded(labeled_vid_dir, project_dir): - """Grabs CollectedData and img files from project_dir, moves to labeled""" - for file in project_dir.parent.parent.glob("*"): - if file.is_dir(): - continue - dest = labeled_vid_dir / file.name - if dest.exists(): - dest.unlink() - dest.write_bytes(file.read_bytes()) - # TODO: revert to rename before merge - # file.rename(labeled_vid_dir / file.name) - - yield - - -@pytest.fixture(scope="session") -def add_training_files(dlc_project_tbl, project_key, fix_downloaded): +def add_training_files(dlc_project_tbl, project_key, labeled_vid_dir): + DOWNLOADS.move_dlc_items(labeled_vid_dir) dlc_project_tbl.add_training_files(project_key, skip_duplicates=True) yield @@ -1009,11 +983,13 @@ def model_train_key(sgp, project_key, training_params_key): @pytest.fixture(scope="session") -def populate_training(sgp, fix_downloaded, model_train_key, add_training_files): +def populate_training( + sgp, model_train_key, add_training_files, labeled_vid_dir +): train_tbl = sgp.v1.DLCModelTraining if len(train_tbl & model_train_key) == 0: _ = add_training_files - _ = fix_downloaded + DOWNLOADS.move_dlc_items(labeled_vid_dir) sgp.v1.DLCModelTraining.populate(model_train_key) yield model_train_key diff --git a/tests/data_downloader.py b/tests/data_downloader.py index 98a254eda..7be7135de 100644 --- a/tests/data_downloader.py +++ b/tests/data_downloader.py @@ -1,10 +1,14 @@ from functools import cached_property from os import environ as os_environ from pathlib import Path +from shutil import copy as shutil_copy from subprocess import DEVNULL, Popen from sys import stderr, stdout +from time import sleep as time_sleep from typing import Dict, Union +from datajoint import logger as dj_logger + UCSF_BOX_USER = os_environ.get("UCSF_BOX_USER") UCSF_BOX_TOKEN = os_environ.get("UCSF_BOX_TOKEN") BASE_URL = "ftps://ftp.box.com/trodes_to_nwb_test_data/" @@ -87,6 +91,7 @@ def __init__( self.cmd_kwargs = dict(stdout=stdout, stderr=stderr) self.base_dir = Path(base_dir).resolve() + self.download_dlc = download_dlc self.file_paths = file_paths if download_dlc else file_paths[:NON_DLC] self.base_dir.mkdir(exist_ok=True) @@ -112,28 +117,42 @@ def file_downloads(self) -> Dict[str, Union[Popen, None]]: for path in self.file_paths: target, url = path["target_name"], path["url"] target_dir = self.base_dir / path["relative_dir"] + target_dir.mkdir(exist_ok=True, parents=True) dest = target_dir / target + cmd = ( + ["echo", f"Already have {target}"] + if dest.exists() + else self.cmd + [target_dir, url] + ) + ret[target] = Popen(cmd, **self.cmd_kwargs) + return ret - if dest.exists(): - ret[target] = None - continue + def wait_for(self, target: str): + """Wait for target to finish downloading.""" + status = self.file_downloads.get(target).poll() + limit = 10 + while status is None and limit > 0: + time_sleep(5) # Some + limit -= 1 + status = self.file_downloads.get(target).poll() + if status != 0: + raise ValueError(f"Error downloading: {target}") + if limit < 1: + raise TimeoutError(f"Timeout downloading: {target}") - target_dir.mkdir(exist_ok=True, parents=True) - ret[target] = Popen(self.cmd + [target_dir, url], **self.cmd_kwargs) - return ret + def 
move_dlc_items(self, dest_dir: Path): + """Move completed DLC files to dest_dir.""" + if not self.download_dlc: + return + if not dest_dir.exists(): + dest_dir.mkdir(parents=True) - def check_download(self, download, info): - if download is not None: - download.wait() - if download.returncode: - return download - return None - - @property - def download_errors(self): - ret = [] - for download, item in zip(self.file_downloads, self.file_paths): - if d_status := self.check_download(download, item): - ret.append(d_status) - continue - return ret + for path in self.file_paths[NON_DLC:]: + target = path["target_name"] + self.wait_for(target) # Could be faster if moved finished first + + src_path = self.base_dir / path["relative_dir"] / target + dest_path = dest_dir / src_path.name + if not dest_path.exists(): + shutil_copy(str(src_path), str(dest_path)) + dj_logger.info(f"Moved: {src_path} -> {dest_path}") From 3d775c33321d868a58bbd2252021b5fe4d3ebd2f Mon Sep 17 00:00:00 2001 From: Chris Broz Date: Fri, 14 Jun 2024 11:41:49 -0500 Subject: [PATCH 10/14] Update src/spyglass/utils/dj_helper_fn.py Co-authored-by: Samuel Bray --- src/spyglass/utils/dj_helper_fn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 0b9155cc7..1be5e9796 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -262,7 +262,7 @@ def fetch_nwb(query_expression, nwb_master, *attrs, **kwargs): rec_dict["nwb2load_filepath"] = file_path continue - # Full dect caused issues with dlc tables using dicts in secondary keys + # Full dict caused issues with dlc tables using dicts in secondary keys rec_only_pk = {k: rec_dict[k] for k in query_table.heading.primary_key} rec_dict["nwb2load_filepath"] = (query_table & rec_only_pk).fetch1( "nwb2load_filepath" From 824894c3e7a6a492cd28b269bc8bc31d5def0ef1 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 20 Jun 2024 16:30:20 -0500 Subject: [PATCH 11/14] Ignore unimported non-spyglass in cascade --- src/spyglass/utils/dj_graph.py | 23 +++++++++++++++++++++-- src/spyglass/utils/dj_mixin.py | 15 ++++++++++++--- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 3e010e22b..3e90d4736 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -25,6 +25,7 @@ from tqdm import tqdm from spyglass.utils import logger +from spyglass.utils.database_settings import SHARED_MODULES from spyglass.utils.dj_helper_fn import ( PERIPHERAL_TABLES, fuzzy_get, @@ -260,6 +261,16 @@ def _get_ft(self, table, with_restr=False, warn=True): return ft & restr + def _is_out(self, table, warn=True): + """Check if table is outside of spyglass.""" + table = self._ensure_names(table) + if self.graph.nodes.get(table): + return False + ret = table.split(".")[0].split("_")[0].strip("`") not in SHARED_MODULES + if warn and ret: # Log warning if outside + logger.warning(f"Skipping unimported: {table}") + return ret + # ---------------------------- Graph Traversal ----------------------------- def _bridge_restr( @@ -299,15 +310,19 @@ def _bridge_restr( List[Dict[str, str]] List of dicts containing primary key fields for restricted table2. 
""" + if self._is_out(table2) or self._is_out(table1): # 2 more likely + return ["False"] # Stop cascade if outside, see #1002 + if not all([direction, attr_map]): dir_bool, edge = self._get_edge(table1, table2) direction = "up" if dir_bool else "down" attr_map = edge.get("attr_map") + # May return empty table if outside imported and outside spyglass ft1 = self._get_ft(table1) & restr ft2 = self._get_ft(table2) - if len(ft1) == 0: + if len(ft1) == 0 or len(ft2) == 0: return ["False"] if bool(set(attr_map.values()) - set(ft1.heading.names)): @@ -462,9 +477,13 @@ def _topo_sort( Whether to reverse the order. Default False. If true, bottom-up. If None, return nodes as is. """ - nodes = self._ensure_names(nodes) if reverse is None: return nodes + nodes = [ + node + for node in self._ensure_names(nodes) + if not self._is_out(node, warn=False) + ] graph = self.graph.subgraph(nodes) if subgraph else self.graph ordered = unite_master_parts(list(topological_sort(graph))) if reverse: diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index ba4ab6652..c282b5e1e 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -239,14 +239,18 @@ def fetch_pynapple(self, *attrs, **kwargs): # ------------------------ delete_downstream_parts ------------------------ - def _import_merge_tables(self): + def _import_part_masters(self): """Import all merge tables.""" from spyglass.decoding.decoding_merge import DecodingOutput # noqa F401 + from spyglass.decoding.v1.core import PositionGroup # noqa F401 from spyglass.lfp.lfp_merge import LFPOutput # noqa F401 from spyglass.linearization.merge import ( LinearizedPositionOutput, ) # noqa F401 from spyglass.position.position_merge import PositionOutput # noqa F401 + from spyglass.spikesorting.analysis.v1.group import ( # noqa F401 + SortedSpikesGroup, + ) from spyglass.spikesorting.spikesorting_merge import ( # noqa F401 SpikeSortingOutput, ) @@ -255,7 +259,9 @@ def _import_merge_tables(self): DecodingOutput(), LFPOutput(), LinearizedPositionOutput(), + PositionGroup(), PositionOutput(), + SortedSpikesGroup(), SpikeSortingOutput(), ) @@ -286,7 +292,7 @@ def search_descendants(parent): _ = search_descendants(self) except NetworkXError: try: # Attempt to import missing table - self._import_merge_tables() + self._import_part_masters() _ = search_descendants(self) except NetworkXError as e: table_name = "".join(e.args[0].split("`")[1:4]) @@ -343,6 +349,7 @@ def delete_downstream_parts( reload_cache: bool = False, disable_warning: bool = False, return_graph: bool = False, + verbose: bool = False, **kwargs, ) -> List[dj.FreeTable]: """Delete downstream merge table entries associated with restriction. @@ -365,6 +372,8 @@ def delete_downstream_parts( If True, return RestrGraph object used to identify downstream tables. Default False, return list of part FreeTables. True. If False, return dictionary of merge tables and their joins. + verbose : bool, optional + If True, call RestrGraph with verbose=True. Default False. **kwargs : Any Passed to datajoint.table.Table.delete. 
""" @@ -382,7 +391,7 @@ def delete_downstream_parts( leaves={self.full_table_name: restriction}, direction="down", cascade=True, - verbose=False, + verbose=verbose, ) if return_graph: From 1112e2ef95f051a21e8c16e5563c5e5b3ad49a45 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 24 Jun 2024 13:56:25 -0500 Subject: [PATCH 12/14] Load part-master cache before graph --- src/spyglass/utils/dj_mixin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index dc1a888cf..877bcf40f 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -390,6 +390,7 @@ def delete_downstream_parts( if reload_cache: _ = self.__dict__.pop("_part_masters", None) + _ = self._part_masters # load cache before loading graph restriction = restriction or self.restriction or True restr_graph = RestrGraph( From ba9333adfc43da8eb23ede3b5356b24d76922c50 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 25 Jun 2024 13:54:23 -0500 Subject: [PATCH 13/14] Add more automatic imports --- src/spyglass/utils/dj_mixin.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 877bcf40f..fd1e48dec 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -246,29 +246,53 @@ def fetch_pynapple(self, *attrs, **kwargs): # ------------------------ delete_downstream_parts ------------------------ def _import_part_masters(self): - """Import all merge tables.""" + """Import tables that may constrain a RestrGraph. See #1002""" + from spyglass.common.common_ripple import ( + RippleLFPSelection, + ) # noqa F401 from spyglass.decoding.decoding_merge import DecodingOutput # noqa F401 + from spyglass.decoding.v0.clusterless import ( # noqa F401 + UnitMarksIndicatorSelection, + ) + from spyglass.decoding.v0.sorted_spikes import ( # noqa F401 + SortedSpikesIndicatorSelection, + ) from spyglass.decoding.v1.core import PositionGroup # noqa F401 + from spyglass.lfp.analysis.v1 import LFPBandSelection # noqa F401 from spyglass.lfp.lfp_merge import LFPOutput # noqa F401 - from spyglass.linearization.merge import ( + from spyglass.linearization.merge import ( # noqa F401 LinearizedPositionOutput, - ) # noqa F401 + LinearizedPositionV1, + ) + from spyglass.mua.v1.mua import MuaEventsV1 # noqa F401 from spyglass.position.position_merge import PositionOutput # noqa F401 + from spyglass.ripple.v1.ripple import RippleTimesV1 # noqa F401 from spyglass.spikesorting.analysis.v1.group import ( # noqa F401 SortedSpikesGroup, ) from spyglass.spikesorting.spikesorting_merge import ( # noqa F401 SpikeSortingOutput, ) + from spyglass.spikesorting.v0.figurl_views import ( # noqa F401 + SpikeSortingRecordingView, + ) _ = ( DecodingOutput(), + LFPBandSelection(), LFPOutput(), LinearizedPositionOutput(), + LinearizedPositionV1(), + MuaEventsV1(), PositionGroup(), PositionOutput(), + RippleLFPSelection(), + RippleTimesV1(), SortedSpikesGroup(), + SortedSpikesIndicatorSelection(), SpikeSortingOutput(), + SpikeSortingRecordingView(), + UnitMarksIndicatorSelection(), ) @cached_property From 01d361f263e029ce74d8a4c1852ee6f9343b2f80 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 25 Jun 2024 15:06:21 -0500 Subject: [PATCH 14/14] Pin twine req for build --- .github/workflows/test-package-build.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-package-build.yml b/.github/workflows/test-package-build.yml index 
41aace719..c93b77398 100644 --- a/.github/workflows/test-package-build.yml +++ b/.github/workflows/test-package-build.yml @@ -27,7 +27,9 @@ jobs: - uses: actions/setup-python@v5 with: python-version: 3.9 - - run: pip install --upgrade build twine + - run: | + pip install --upgrade build twine + pip install importlib_metadata==7.2.1 # twine #977 - name: Build sdist and wheel run: python -m build - run: twine check dist/*
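Stepping back from the CI pin above, the core guard added in patch 11, `_is_out`, reduces to checking a table's schema prefix against the shared-module allowlist before cascading into it. A rough, self-contained sketch of that check (the module list below is an illustrative subset, not the real `SHARED_MODULES` from `spyglass.utils.database_settings`):

```python
# Illustrative subset only -- the real allowlist lives in
# spyglass.utils.database_settings.SHARED_MODULES.
SHARED_MODULES = ["common", "lfp", "position", "spikesorting", "decoding"]


def is_outside_spyglass(full_table_name: str) -> bool:
    """Return True if a `database`.`table` name is outside the shared modules."""
    # "`position_v1_dlc`.`some_table`" -> schema "position_v1_dlc" -> prefix "position"
    schema_prefix = full_table_name.split(".")[0].strip("`").split("_")[0]
    return schema_prefix not in SHARED_MODULES


print(is_outside_spyglass("`common_session`.`session`"))  # False: stay in the graph
print(is_outside_spyglass("`my_lab_schema`.`analysis`"))  # True: skip, see #1002
```

Skipping such tables stops the cascade with a `["False"]` restriction rather than erroring on an unimported class, which is the behavior `_bridge_restr` adopts in the same patch.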