Allow disable transaction for select populates #1067

Merged: 10 commits, Aug 29, 2024
17 changes: 14 additions & 3 deletions CHANGELOG.md
@@ -1,5 +1,15 @@
# Change Log

## [0.5.4] (Unreleased)

### Release Notes

<!-- Running draft to be removed immediately prior to release. -->

### Infrastructure

- Disable populate transaction protection for long-populating tables #1066

## [0.5.3] (August 27, 2024)

### Infrastructure
@@ -25,9 +35,9 @@
- Allow `ModuleNotFoundError` or `ImportError` for optional dependencies #1023
- Ensure integrity of group tables #1026
- Convert list of LFP artifact removed interval list to array #1046
- Merge duplicate functions in decoding and spikesorting #1050, #1053, #1062,
  #1066, #1069
- Revise docs organization.
- Misc -> Features/ForDevelopers. #1029
- Installation instructions -> Setup notebook. #1029
- Migrate SQL export tools to `utils` to support exporting `DandiPath` #1048
@@ -320,3 +330,4 @@
[0.5.1]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.1
[0.5.2]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.2
[0.5.3]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.3
[0.5.4]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.4
30 changes: 30 additions & 0 deletions docs/src/Features/Mixin.md
@@ -188,3 +188,33 @@ nwbfile = Nwbfile()
(nwbfile & "nwb_file_name LIKE 'Name%'").ddp(dry_run=False)
(nwbfile & "nwb_file_name LIKE 'Other%'").ddp(dry_run=False)
```

## Populate Calls

The mixin also overrides the default `populate` function to support non-daemon
process pools and to optionally disable transaction protection.

### Non-Daemon Process Pools

To allow the `make` function to spawn a new process pool, the mixin overrides
the default `populate` function for tables with `_parallel_make` set to `True`.
See [issue #1000](https://github.com/LorenFrankLab/spyglass/issues/1000) and
[PR #1001](https://github.com/LorenFrankLab/spyglass/pull/1001) for more
information.
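
As a sketch of how a table opts in (not taken from this PR's diff; the schema,
table, and upstream names below are hypothetical):

```python
import multiprocessing as mp

import datajoint as dj

from spyglass.utils import SpyglassMixin

schema = dj.schema("example_parallel")  # hypothetical schema name


@schema
class ParallelSelection(dj.Manual):
    definition = """
    sel_id: int
    """


@schema
class ParallelOutput(SpyglassMixin, dj.Computed):
    definition = """
    -> ParallelSelection
    ---
    result: longblob
    """

    # Tell the mixin's populate to use a non-daemon process pool so that
    # this make may spawn its own worker processes.
    _parallel_make = True

    def make(self, key):
        # Daemonic pool workers cannot spawn children; non-daemon ones can.
        with mp.Pool(2) as pool:
            result = pool.map(abs, [-1, -2, -3])
        self.insert1(dict(key, result=result))


# Multi-process populate; keys are dispatched to non-daemon workers.
# ParallelOutput().populate(processes=2)
```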

### Disable Transaction Protection

By default, DataJoint wraps the `populate` function in a transaction to ensure
data integrity (see
[Transactions](https://docs.datajoint.io/python/definition/05-Transactions.html)).

This can cause issues when populating large tables if another user attempts to
declare/modify a table while the transaction is open (see
[issue #1030](https://github.com/LorenFrankLab/spyglass/issues/1030) and
[DataJoint issue #1170](https://github.com/datajoint/datajoint-python/issues/1170)).

Tables with `_use_transaction` set to `False` are not wrapped in a transaction
when calling `populate`. Instead, upstream data are hashed before and after the
populate to verify that no upstream rows changed during the unprotected run.
The extra hashing time is an acceptable trade-off for already time-consuming
populates, since it avoids blocking other users.
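
A sketch of opting out (again with hypothetical names; the per-call
`use_transaction` keyword is the one added to `populate` in this PR):

```python
import datajoint as dj

from spyglass.utils import SpyglassMixin

schema = dj.schema("example_no_transact")  # hypothetical schema name


@schema
class SlowSelection(dj.Manual):
    definition = """
    sel_id: int
    """


@schema
class SlowOutput(SpyglassMixin, dj.Computed):
    definition = """
    -> SlowSelection
    ---
    result: longblob
    """

    # Skip DataJoint's transaction wrapper; the mixin instead hashes
    # upstream rows before and after populate to detect interference.
    _use_transaction = False
    # Permit make to insert outside a populate transaction.
    _allow_insert = True

    def make(self, key):
        result = [i**2 for i in range(10)]  # stand-in for a long computation
        self.insert1(dict(key, result=result))


# Per-call override of the class default, re-enabling the transaction:
# SlowOutput().populate(use_transaction=True)
```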
2 changes: 2 additions & 0 deletions src/spyglass/position/v1/position_dlc_training.py
@@ -102,7 +102,9 @@ class DLCModelTraining(SpyglassMixin, dj.Computed):
latest_snapshot: int unsigned # latest exact snapshot index (i.e., never -1)
config_template: longblob # stored full config file
"""

log_path = None
_use_transaction, _allow_insert = False, True

# To continue from previous training snapshot,
# devs suggest editing pose_cfg.yml
2 changes: 2 additions & 0 deletions src/spyglass/spikesorting/v1/figurl_curation.py
@@ -117,6 +117,8 @@ class FigURLCuration(SpyglassMixin, dj.Computed):
url: varchar(1000)
"""

_use_transaction, _allow_insert = False, True

def make(self, key: dict):
# FETCH
query = (
2 changes: 2 additions & 0 deletions src/spyglass/spikesorting/v1/metric_curation.py
@@ -203,6 +203,8 @@ class MetricCuration(SpyglassMixin, dj.Computed):
object_id: varchar(40) # Object ID for the metrics in NWB file
"""

_use_transaction, _allow_insert = False, True

def make(self, key):
AnalysisNwbfile()._creation_times["pre_create_time"] = time()
# FETCH
2 changes: 2 additions & 0 deletions src/spyglass/spikesorting/v1/sorting.py
@@ -144,6 +144,8 @@ class SpikeSorting(SpyglassMixin, dj.Computed):
time_of_sort: int # in Unix time, to the nearest second
"""

_use_transaction, _allow_insert = False, True

def make(self, key: dict):
"""Runs spike sorting on the data and parameters specified by the
SpikeSortingSelection table and inserts a new entry into the SpikeSorting table.
9 changes: 9 additions & 0 deletions src/spyglass/utils/dj_graph.py
@@ -7,6 +7,7 @@
from copy import deepcopy
from enum import Enum
from functools import cached_property
from hashlib import md5 as hash_md5
from itertools import chain as iter_chain
from typing import Any, Dict, Iterable, List, Set, Tuple, Union

@@ -595,6 +596,14 @@ def leaf_ft(self):
"""Get restricted FreeTables from graph leaves."""
return [self._get_ft(table, with_restr=True) for table in self.leaves]

@property
def hash(self):
    """Return hash of all visited nodes."""
    initial = hash_md5(b"")
    for table in self.all_ft:
        initial.update(table.fetch())
    return initial.hexdigest()

# ------------------------------- Add Nodes -------------------------------

def add_leaf(
3 changes: 2 additions & 1 deletion src/spyglass/utils/dj_helper_fn.py
@@ -516,7 +516,8 @@ def make_file_obj_id_unique(nwb_path: str):
def populate_pass_function(value):
"""Pass function for parallel populate.

Note: To avoid pickling errors, the table must be passed by class,
NOT by instance.
Note: This function must be defined in the global namespace.

Parameters
118 changes: 103 additions & 15 deletions src/spyglass/utils/dj_mixin.py
Expand Up @@ -81,6 +81,8 @@ class SpyglassMixin:
_banned_search_tables = set() # Tables to avoid in restrict_by
_parallel_make = False # Tables that use parallel processing in make

_use_transaction = True # Use transaction in populate.

def __init__(self, *args, **kwargs):
"""Initialize SpyglassMixin.

@@ -410,7 +412,7 @@ def delete_downstream_parts(
**kwargs : Any
Passed to datajoint.table.Table.delete.
"""
    RestrGraph = self._graph_deps[1]

start = time()

@@ -475,7 +477,14 @@ def _delete_deps(self) -> List[Table]:
self._member_pk = LabMember.primary_key[0]
return [LabMember, LabTeam, Session, schema.external, IntervalList]

@cached_property
def _graph_deps(self) -> list:
    from spyglass.utils.dj_graph import RestrGraph  # noqa #F401
    from spyglass.utils.dj_graph import TableChain

    return [TableChain, RestrGraph]

def _get_exp_summary(self):
"""Get summary of experimenters for session(s), including NULL.

Parameters
@@ -513,7 +522,7 @@ def _get_exp_summary(self) -> Union[QueryExpression, None]:
@cached_property
def _session_connection(self):
"""Path from Session table to self. False if no connection found."""
    TableChain = self._graph_deps[0]

return TableChain(parent=self._delete_deps[2], child=self, verbose=True)

Expand Down Expand Up @@ -697,25 +706,104 @@ def super_delete(self, warn=True, *args, **kwargs):
self._log_delete(start=time(), super_delete=True)
super().delete(*args, **kwargs)

    # -------------------------------- populate --------------------------------

def _hash_upstream(self, keys):
"""Hash upstream table keys for no transaction populate.

Uses a RestrGraph to capture all upstream tables, restrict them to
relevant entries, and hash the results. This is used to check if
upstream tables have changed during a no-transaction populate and avoid
the following data-integrity error:

1. User A starts no-transaction populate.
2. User B deletes and repopulates an upstream table, changing contents.
3. User A finishes populate, inserting data that is now invalid.

Parameters
----------
keys : list
List of keys for populating table.
"""
    RestrGraph = self._graph_deps[1]

    if not (parents := self.parents(as_objects=True, primary=True)):
        raise RuntimeError("No upstream tables found for upstream hash.")

    leaves = {  # Restriction on each primary parent
        p.full_table_name: [
            {k: v for k, v in key.items() if k in p.heading.names}
            for key in keys
        ]
        for p in parents
    }

    return RestrGraph(seed_table=self, leaves=leaves, cascade=True).hash

def populate(self, *restrictions, **kwargs):
"""Populate table in parallel.
"""Populate table in parallel, with or without transaction protection.

Supersedes datajoint.table.Table.populate for classes with that
spawn processes in their make function
spawn processes in their make function and always use transactions.

`_use_transaction` class attribute can be set to False to disable
transaction protection for a table. This is not recommended for tables
with short processing times. A before-and-after hash check is performed
to ensure upstream tables have not changed during populate, and may
be a more time-consuming process. To permit the `make` to insert without
populate, set `_allow_insert` to True.
"""

    processes = kwargs.pop("processes", 1)

    # Decide whether to use transaction protection
    use_transact = kwargs.pop("use_transaction", None)
    if use_transact is None:  # if user does not specify, use class default
        use_transact = self._use_transaction
        if self._use_transaction is False:  # If class default is off, warn
            logger.warning(
                "Turning off transaction protection for this table by "
                + "default. Use use_transaction=True to re-enable.\n"
                + "Read more about transactions:\n"
                + "https://docs.datajoint.io/python/definition/05-Transactions.html\n"
                + "https://github.com/LorenFrankLab/spyglass/issues/1030"
            )
    if use_transact is False and processes > 1:
        raise RuntimeError(
            "Must use transaction protection with parallel processing.\n"
            + "Call with use_transaction=True.\n"
            + f"Table default transaction use: {self._use_transaction}"
        )

    # Get keys, needed for no-transact or multi-process w/ _parallel_make
    keys = [True]
    if use_transact is False or (processes > 1 and self._parallel_make):
        keys = (self._jobs_to_do(restrictions) - self.target).fetch(
            "KEY", limit=kwargs.get("limit", None)
        )

    if use_transact is False:
        upstream_hash = self._hash_upstream(keys)
        if kwargs:  # Warn that populate kwargs are ignored, bare `make` is used
            logger.warning(
                "Ignoring kwargs when not using transaction protection."
            )

    if processes == 1 or not self._parallel_make:
        if use_transact:  # Pass single-process populate to super
            kwargs["processes"] = processes
            return super().populate(*restrictions, **kwargs)
        else:  # No transaction protection, use bare make
            for key in keys:
                self.make(key)
            if upstream_hash != self._hash_upstream(keys):
[Inline review thread on this check]

Collaborator: I think it could be made simpler for the user to clean up the
mismatched hash insert by either:

  • printing the key in the raised error
  • automatically deleting the key before the error (enforces integrity)
  • passing the hash result to the make function and doing the check there
    before insert, if there is a hash (enforces integrity, requires more
    edits to the code)

Member Author: Only using the hash makes it tough to determine where the
mismatch occurred. I made a commit to delete all the keys and ask the user to
start over. It's not ideal, but, from the code I've seen, folks primarily run
one key at a time. A mismatch seems unlikely, as does a serious impact in the
timeline between now and a new DJ feature we can use.

Would you rather I run a table-wise comparison across the two RestrGraph
objects to nail down where the mismatch occurred?

Collaborator: I think the overly cautious solution of deleting them all is
fine for now. As you said, it's unlikely to occur, and ensuring consistency is
a valid priority. Users shouldn't notice it, and if they do, we'd need to go
back into the code to figure out why anyway.
(self & keys).delete(force=True)
logger.error(
"Upstream tables changed during non-transaction "
+ "populate. Please try again."
)
return

    # If parallel in both make and populate, use non-daemon processes
    # Package the call list
    call_list = [(type(self), key, kwargs) for key in keys]

@@ -964,7 +1052,7 @@ def restrict_by(
Restricted version of present table or TableChain object. If
return_graph, use all_ft attribute to see all tables in cascade.
"""
    TableChain = self._graph_deps[0]

if restriction is True:
return self