WIP: start add V0 hasher
CBroz1 committed Mar 5, 2025
1 parent 72f8a25 commit 9c27d87
Showing 11 changed files with 415 additions and 75 deletions.
13 changes: 8 additions & 5 deletions CHANGELOG.md
@@ -1,6 +1,6 @@
# Change Log

## [0.5.5] (Unreleased)
## \[0.5.5\] (Unreleased)

### Release Notes

@@ -10,14 +10,17 @@

```python
import datajoint as dj
from spyglass.spikesorting.v1.recording import * # noqa
from spyglass.spikesorting.v1 import recording as v1rec # noqa
from spyglass.spikesorting.v0 import spikesorting_recording as v0rec # noqa
from spyglass.linearization.v1.main import * # noqa

dj.FreeTable(dj.conn(), "common_nwbfile.analysis_nwbfile_log").drop()
dj.FreeTable(dj.conn(), "common_session.session_group").drop()
TrackGraph.alter() # Add edge map parameter
SpikeSortingRecording().alter()
SpikeSortingRecording().update_ids()
v0rec.SpikeSortingRecording().alter()
v0rec.SpikeSortingRecording().update_ids()
v1rec.SpikeSortingRecording().alter()
v1rec.SpikeSortingRecording().update_ids()
```

### Infrastructure
@@ -28,7 +31,7 @@ SpikeSortingRecording().update_ids()
- Improve cron job documentation and script #1226, #1241
- Update export process to include `~external` tables #1239
- Only add merge parts to `source_class_dict` if present in codebase #1237
- Add recompute ability for `SpikeSortingRecording` #1093
- Add recompute ability for `SpikeSortingRecording` #1093

### Pipelines

1 change: 1 addition & 0 deletions docs/mkdocs.yml
@@ -75,6 +75,7 @@ nav:
- Merge Tables: Features/Merge.md
- Export: Features/Export.md
- Centralized Code: Features/Mixin.md
- Recompute: Features/Recompute.md
- For Developers:
- Overview: ForDevelopers/index.md
- How to Contribute: ForDevelopers/Contribute.md
55 changes: 55 additions & 0 deletions docs/src/Features/Recompute.md
@@ -0,0 +1,55 @@
# Recompute

## Why

Some analysis files generated by Spyglass are very unlikely to be accessed
again. Those generated by `SpikeSortingRecording` tables were identified as
taking up tens of terabytes of space while seldom being accessed after their
first use. By recomputing these files on demand, we can save significant server
space at the cost of roughly ten minutes of recompute time per file in the
unlikely event one is needed again.

Spyglass 0.5.5 introduces the ability to delete and later recompute both files
generated after this release and older files generated before it.

## How

`SpikeSortingRecording` has a new `_make_file` method that is called when a
file is accessed but not found on disk. This method regenerates the file and
compares its hash to the hash expected for the original file. If the hashes
match, the file is saved and returned. If they do not match, the file is
deleted and an error is raised. To avoid such errors, follow the steps below.
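
In outline, the pattern is the following. This is a minimal sketch, not the
actual Spyglass implementation: Spyglass uses its own `NwbfileHasher` /
`DirectoryHasher` classes rather than `hashlib`, and the function and argument
names here are illustrative only.

```python
import hashlib
from pathlib import Path


def fetch_or_recompute(path: Path, expected_hash: str, regenerate) -> Path:
    """Return `path`, regenerating and hash-checking the file if missing.

    `regenerate` is any callable that rebuilds the file at `path`.
    """
    if not path.exists():
        regenerate(path)  # rebuild the file from upstream data
        new_hash = hashlib.sha256(path.read_bytes()).hexdigest()
        if new_hash != expected_hash:
            path.unlink()  # discard the mismatched file
            raise ValueError(f"Recomputed hash mismatch for {path.name}")
    return path
```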

### New files

Newly generated files will automatically record information about their
dependencies and the code that generated them in `RecomputeSelection` tables.
To see the dependencies of a file, access `RecordingRecomputeSelection`:

```python
from spyglass.spikesorting.v1 import recompute as v1_recompute

v1_recompute.RecordingRecomputeSelection()
```
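
To inspect the logged dependencies for a particular recording, restrict the
table and fetch as usual with DataJoint; the restriction key below is
hypothetical and should be replaced with a key from your own
`SpikeSortingRecording` table.

```python
from spyglass.spikesorting.v1 import recompute as v1_recompute

# Hypothetical restriction; substitute a real key from SpikeSortingRecording
key = {"nwb_file_name": "example20250305_.nwb"}
(v1_recompute.RecordingRecomputeSelection() & key).fetch(as_dict=True)
```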

### Old files

To ensure that old files can be replicated before they are deleted, we'll need
to:

1. Update the tables to add the new fields.
2. Attempt a file recompute and record dependency info for successful attempts
   (see the sketch after the snippet below).

<!-- TODO: add code snippet. 2 or 3 tables?? -->

```python
from spyglass.spikesorting.v0 import spikesorting_recording as v0_recording
from spyglass.spikesorting.v1 import recording as v1_recording

# Alter tables to include new fields, updating values
v0_recording.SpikeSortingRecording().alter()
v0_recording.SpikeSortingRecording().update_ids()
v1_recording.SpikeSortingRecording().alter()
v1_recording.SpikeSortingRecording().update_ids()
```
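
The snippet the TODO above calls for is not yet part of this commit. A sketch
of what step 2 might look like with the v0 tables added in this commit (the v1
equivalents would be analogous); note that `RecordingRecompute.make` is still a
stub here:

```python
from spyglass.spikesorting.v0 import spikesorting_recompute as v0_recompute
from spyglass.spikesorting.v0 import spikesorting_recording as v0_recording

# Register existing recordings as recompute candidates, logging the current
# pip environment for each (at_creation=False marks them as not yet verified)
keys = v0_recording.SpikeSortingRecording().fetch("KEY")
v0_recompute.RecordingRecomputeSelection().insert(keys, at_creation=False)

# Attempt recompute for entries whose logged environment matches this one
v0_recompute.RecordingRecompute().populate()
```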
8 changes: 4 additions & 4 deletions src/spyglass/common/common_nwbfile.py
@@ -478,7 +478,7 @@ def get_hash(
Returns
-------
file_hash : [str, NwbfileHasher]
hash : [str, NwbfileHasher]
The hash of the file contents or the hasher object itself.
"""
hasher = NwbfileHasher(
@@ -487,7 +487,7 @@
)
return hasher if return_hasher else hasher.hash

def _update_external(self, analysis_file_name: str, file_hash: str):
def _update_external(self, analysis_file_name: str, hash: str):
"""Update the external contents checksum for an analysis file.
USE WITH CAUTION. If the hash does not match the file contents, the file
@@ -497,15 +497,15 @@ def _update_external(self, analysis_file_name: str, file_hash: str):
----------
analysis_file_name : str
The name of the analysis NWB file.
file_hash : str
hash : str
The hash of the file contents as calculated by NwbfileHasher.
If the hash does not match the file contents, the file and
downstream entries are deleted.
"""
file_path = self.get_abs_path(analysis_file_name, from_schema=True)
new_hash = self.get_hash(analysis_file_name, from_schema=True)

if file_hash != new_hash:
if hash != new_hash:
Path(file_path).unlink() # remove mismatched file
# force delete, including all downstream, forcing permissions
del_kwargs = dict(force_permission=True, safemode=False)
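
For orientation, the renamed method would be invoked roughly as follows. This
is a sketch: the file name and stored hash are placeholders, and the real hash
would be the one recorded for the original file (e.g., via `get_hash`) before
it was deleted.

```python
from spyglass.common.common_nwbfile import AnalysisNwbfile

tbl = AnalysisNwbfile()
analysis_file_name = "example20250305_ABC123.nwb"  # placeholder file name

# Hash previously recorded for the original file (placeholder value here).
# _update_external recomputes the hash of the file now on disk and, on a
# mismatch, deletes the file and its downstream entries.
expected_hash = "0123456789abcdef"
tbl._update_external(analysis_file_name, hash=expected_hash)
```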
39 changes: 39 additions & 0 deletions src/spyglass/spikesorting/v0/__init__.py
@@ -30,6 +30,10 @@
SpikeSortingPipelineParameters,
spikesorting_pipeline_populator,
)
from spyglass.spikesorting.v0.spikesorting_recompute import ( # noqa: F401
RecordingRecompute,
RecordingRecomputeSelection,
)
from spyglass.spikesorting.v0.spikesorting_recording import ( # noqa: F401
SortGroup,
SortInterval,
@@ -42,3 +46,38 @@
SpikeSorting,
SpikeSortingSelection,
)

__all__ = [
"ArtifactDetection",
"ArtifactDetectionParameters",
"ArtifactDetectionSelection",
"ArtifactRemovedIntervalList",
"AutomaticCuration",
"AutomaticCurationParameters",
"AutomaticCurationSelection",
"CuratedSpikeSorting",
"CuratedSpikeSortingSelection",
"Curation",
"CurationFigurl",
"CurationFigurlSelection",
"MetricParameters",
"MetricSelection",
"QualityMetrics",
"RecordingRecompute",
"RecordingRecomputeSelection",
"SortGroup",
"SortInterval",
"SortingviewWorkspace",
"SortingviewWorkspaceSelection",
"SpikeSorterParameters",
"SpikeSorting",
"SpikeSortingPipelineParameters",
"SpikeSortingPreprocessingParameters",
"SpikeSortingRecording",
"SpikeSortingRecordingSelection",
"SpikeSortingSelection",
"WaveformParameters",
"WaveformSelection",
"Waveforms",
"spikesorting_pipeline_populator",
]
148 changes: 148 additions & 0 deletions src/spyglass/spikesorting/v0/spikesorting_recompute.py
@@ -0,0 +1,148 @@
"""This schema is used to track recompute capabilities for existing files."""

from functools import cached_property
from os import environ as os_environ

import datajoint as dj
from numpy import __version__ as np_version
from probeinterface import __version__ as pi_version
from spikeinterface import __version__ as si_version

from spyglass.spikesorting.v0.spikesorting_recording import (
SpikeSortingRecording,
) # noqa F401
from spyglass.utils import logger
from spyglass.utils.nwb_hash import DirectoryHasher

schema = dj.schema("cbroz_temp_v0")


@schema
class RecordingRecomputeSelection(dj.Manual):
definition = """
-> SpikeSortingRecording
---
logged_at_creation=0: bool
pip_deps: blob # dict of pip dependencies
"""

@cached_property
def default_attempt_id(self):
user = dj.config["database.user"]
conda = os_environ.get("CONDA_DEFAULT_ENV", "base")
return f"{user}_{conda}"

@cached_property
def pip_deps(self):
return dict(
spikeinterface=si_version,
probeinterface=pi_version,
numpy=np_version,
)

def key_pk(self, key):
return {k: key[k] for k in self.primary_key}

def insert(self, rows, at_creation=False, **kwargs):
"""Custom insert to ensure dependencies are added to each row."""
if not isinstance(rows, list):
rows = [rows]
if not isinstance(rows[0], dict):
raise ValueError("Rows must be a list of dicts")

inserts = []
for row in rows:
key_pk = self.key_pk(row)
inserts.append(
dict(
**key_pk,
attempt_id=row.get("attempt_id", self.default_attempt_id),
                    pip_deps=self.pip_deps,
logged_at_creation=at_creation,
)
)
super().insert(inserts, **kwargs)

# --- Gatekeep recompute attempts ---

@cached_property
def this_env(self):
"""Restricted table matching pynwb env and pip env.
Serves as key_source for RecordingRecompute. Ensures that recompute
attempts are only made when the pynwb and pip environments match the
records. Also skips files whose environment was logged on creation.
"""

restr = []
for key in self & "logged_at_creation=0":
if key["dependencies"] != self.pip_deps:
continue
restr.append(self.key_pk(key))
return self & restr

def _has_matching_pip(self, key, show_err=True) -> bool:
"""Check current env for matching pip versions."""
query = self.this_env & key

        if len(query) != 1:
raise ValueError(f"Query returned {len(query)} entries: {query}")

need = query.fetch1("dependencies")
ret = need == self.pip_deps

if not ret and show_err:
logger.error(
f"Pip version mismatch. Skipping key: {self.key_pk(key)}"
+ f"\n\tHave: {self.pip_deps}"
+ f"\n\tNeed: {need}"
)

return ret


@schema
class RecordingRecompute(dj.Computed):
definition = """
-> RecordingRecomputeSelection
---
matched:bool
"""

_hasher_cache = dict()

class Name(dj.Part):
definition = """ # File names missing from old or new versions
-> master
name: varchar(255)
missing_from: enum('old', 'new')
"""

class Hash(dj.Part):
definition = """ # File hashes that differ between old and new versions
-> master
name : varchar(255)
"""

def _parent_key(self, key):
ret = SpikeSortingRecording * RecordingRecomputeSelection & key
if len(ret) != 1:
raise ValueError(f"Query returned {len(ret)} entries: {ret}")
return ret.fetch(as_dict=True)[0]

def _hash_one(self, key):
key_hash = dj.hash.key_hash(key)
if key_hash in self._hasher_cache:
return self._hasher_cache[key_hash]
hasher = DirectoryHasher(
path=self._parent_key(key)["recording_path"],
keep_file_hash=True,
)
self._hasher_cache[key_hash] = hasher
return hasher

def make(self, key):
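        # WIP: intended to rebuild the recording, hash the result with
        # DirectoryHasher, and record whether it matches the original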
pass

    def delete_file(self, key):
        pass  # TODO: Add means of deleting replicated files
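
A brief sketch of how the environment gatekeeping above might be exercised
interactively; the restriction key is hypothetical and must identify exactly
one selection entry.

```python
from spyglass.spikesorting.v0 import spikesorting_recompute as v0_recompute

sel = v0_recompute.RecordingRecomputeSelection()

# Entries eligible for recompute in the current environment: logged pip
# versions match and the file was not already verified at creation
print(sel.this_env)

# Explicitly check one entry (hypothetical key matching exactly one row)
key = {"nwb_file_name": "example20250305_.nwb", "sort_group_id": 0}
sel._has_matching_pip(key)
```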