Recompute method in computed tables #917

Open · lfrank opened this issue Apr 5, 2024 · 5 comments
Labels: enhancement (New feature or request), infrastructure (Unix, MySQL, etc. settings/issues impacting users)

Comments

lfrank (Contributor) commented Apr 5, 2024

One of our main goals is to be able to share complete pipelines with results. Sharing of the various intermediate computed data outputs is one way to do this, but for data that can be computed relatively quickly, we could also enable remote users to recompute on the fly.

One possible solution would be to add a recompute method to each dj.Computed table that would regenerate the NWB file (using the same name) if it was not present locally. This would be quite a lot of work, and it would also require careful thought about what to do when upstream computed results are not available. However, if we could get this to work, we could share a much smaller subset of results when we publish papers, which would likely help.

samuelbray32 (Collaborator) commented:

Thoughts on a potential structure:

  • Recompute is essentially recalling the make function outside of a DataJoint transaction and avoiding any insert statements therein.
    • We could add a recompute=False argument to make functions and change their insert statements to only run when recompute is False.
  • Files would only need to be recomputed when accessed. This could be a new fallback in fetch_nwb that calls Table.make(key, recompute=True) when the analysis NWB can't be obtained another way.
    • The benefit is that this would recursively propagate recompute up to other missing intermediate tables needed by the original recompute call, since the upstream data should be accessed through fetch_nwb.

I'm sure there are some edge cases in things like spikesorting that I'm not thinking of, but this might handle a lot without too much change to the code.
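
A minimal sketch of that idea, assuming a hypothetical table and helpers (UpstreamTable, run_analysis, and the fetch_nwb fallback shown here are illustrative placeholders, not the actual Spyglass API):

import datajoint as dj


class ExampleComputed(dj.Computed):
    # Hypothetical computed table; a real table would be bound to a schema.
    definition = """
    -> UpstreamTable
    ---
    analysis_file_name: varchar(64)
    """

    def make(self, key, recompute=False):
        # Upstream data is accessed through fetch_nwb, so a missing upstream
        # file would trigger its own recompute recursively.
        upstream_nwb = (UpstreamTable & key).fetch_nwb()
        file_name = run_analysis(upstream_nwb, key)  # regenerates the same file name

        if not recompute:
            # Normal populate path: insert inside the populate transaction.
            self.insert1(dict(key, analysis_file_name=file_name))
        # With recompute=True, the file is rewritten under its original name
        # and no insert statements run.


def fetch_nwb_with_fallback(table, key):
    """Hypothetical fetch_nwb fallback: recompute the file if it is missing."""
    try:
        return (table & key).fetch_nwb()
    except FileNotFoundError:
        table().make(key, recompute=True)
        return (table & key).fetch_nwb()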

edeno added the enhancement and infrastructure labels on Apr 19, 2024
CBroz1 (Member) commented Aug 28, 2024

We previously added, then removed, logging of file size and creation time to the AnalysisNwbfileLog table. I put together the following script to look at per-table means and identify good recompute candidates.

Script
from pathlib import Path

import pandas as pd
from datajoint.utils import to_camel_case
from hurry.filesize import size  # REQUIRES: pip install hurry.filesize

from spyglass.common.common_nwbfile import AnalysisNwbfileLog

DATA_PATH = Path("data.pkl")
DATA_BCK = Path("data_bck.pkl")


class LA:
    def __init__(self, fetch=False):
        self.data = (
            AnalysisNwbfileLog().fetch(format="frame")
            if not DATA_BCK.exists() or fetch
            else pd.read_pickle(DATA_BCK)
        )
        self.data.to_pickle(DATA_BCK)
        self._grouped = None

    def load_from_backup(self):
        self.data = pd.read_pickle(DATA_BCK)
        self.reformat()
        self._grouped = None

    def drop_if_cols(self, data, cols):
        cols = [col for col in cols if col in data.columns]
        if not cols:  # no columns to drop
            return data
        data = data.drop(columns=cols)  # axis is implied by columns=
        return data

    def reformat(self):

        def to_tbl_name(full_name):
            if "." not in full_name:
                return full_name, full_name
            schema, table = full_name.replace("`", "").split(".")
            return schema, to_camel_case(table)

        self.data["full_table_name"] = self.data["table"]
        self.data["schema"] = self.data["full_table_name"].apply(
            lambda x: to_tbl_name(x)[0] if x is not None else None
        )
        self.data["table"] = self.data["full_table_name"].apply(
            lambda x: to_tbl_name(x)[1] if x is not None else None
        )

        self.data = self.drop_if_cols(
            self.data, ["analysis_file_name", "full_table_name"]
        )

        numeric_cols = ["time_delta", "file_size", "accessed"]
        self.data[numeric_cols] = self.data[numeric_cols].apply(pd.to_numeric)

        self.data.to_pickle(DATA_PATH)

    @property
    def grouped(self):
        if self._grouped is not None:
            return self._grouped
        self.reformat()
        grouped = self.drop_if_cols(
            self.data, ["dj_user", "timestamp"]
        ).groupby(["schema", "table"])

        mean_df = grouped.mean()
        mean_df = mean_df[mean_df["time_delta"].notnull()]
        sorted_df = mean_df.sort_values("file_size", ascending=False)

        def sec_to_min(sec):
            minutes = round(sec / 60, 2)  # avoid shadowing built-in min
            return f"{minutes} min"

        def adj_accessed(accessed):
            """Adjust accessed count to be more readable, fix indexing."""
            return round(accessed + 1, 2)

        sorted_df["time_delta"] = sorted_df["time_delta"].apply(sec_to_min)
        sorted_df["file_size"] = sorted_df["file_size"].apply(size)
        sorted_df["accessed"] = sorted_df["accessed"].apply(adj_accessed)

        self._grouped = sorted_df
        return self._grouped


if __name__ == "__main__":
    la = LA(fetch=True)
    print(la.grouped)

The results are below. These are per-table means of time to create, file size, and number of times accessed (including file creation), ordered by file size.

                                                      time_delta file_size  accessed
schema                          table
spikesorting_v1_recording       SpikeSortingRecording   9.85 min        6G      1.00
spikesorting_v1_metric_curation MetricCuration           8.3 min      888M      3.28
lfp_v1                          LFPV1                  12.41 min      108M     21.60
lfp_band_v1                     LFPBandV1               0.19 min       92M     24.39
spikesorting_curation           Waveforms               0.52 min       28M      1.00
position_linearization_v1       LinearizedPositionV1    0.17 min       26M     20.21
position_v1_trodes_position     TrodesPosV1             0.11 min       23M     19.69
position_v1_dlc_pose_estimation DLCPoseEstimation       0.12 min       14M      7.58
position_v1_dlc_position        DLCSmoothInterp         0.88 min        8M      4.91
spikesorting_v1_sorting         SpikeSorting            5.18 min        7M      7.51
spikesorting_v1_curation        CurationV1              0.09 min        6M      5.03
position_v1_dlc_centroid        DLCCentroid              0.1 min        4M      2.34
position_v1_dlc_selection       DLCPosV1                0.13 min        3M      4.50
decoding_waveform_features      UnitWaveformFeatures    6.31 min        2M      6.43
spikesorting_curation           CuratedSpikeSorting     0.08 min      878K      6.60
position_v1_dlc_orient          DLCOrientation          0.07 min      876K      2.02
spikesorting_curation           QualityMetrics          0.67 min      468K      1.00

All conclusions assume we have a representative sample, which may not be the case. I'll also assume we want to target seldom re-accessed files for deletion and on-demand recompute.

The following tables produced files that were seldom re-accessed:
v1 SpikeSortingRecording, v0 Waveforms, and v0 QualityMetrics. Only two of the 700+ SpikeSortingRecording files were re-accessed, once and twice respectively.

  • If we're willing to tolerate a 10-minute recompute time, focusing on SpikeSortingRecording will let us clear out an average of 6 GB per file across 700+ cases (roughly 4 TB total).
  • If we want to keep recompute time under 1 minute, Waveforms is a better candidate, but would only save 200 MB in each case (72 GB total, 1.5% of the former case).

@edeno - Are both these operations deterministic?

samuelbray32 (Collaborator) commented:

I would also put 'spikesorting_recording.__spike_sorting_recording' (the v0 version) on the priority list. It didn't show up in the logging since it's not stored in an AnalysisNwbfile, but it should have roughly the same size and access rates as spikesorting_v1_recording.

CBroz1 (Member) commented Nov 19, 2024

Updates

I have a working version of a hasher in #1093. Ideally, we could regenerate some test files and start compiling a list of the files that match. Unfortunately, none of my randomly selected files from SpikeSortingV1 have matched so far, due to small differences in saved dependency versions or larger differences like mismatched data. My working theory is that changing the dependencies in my recompute environment will resolve these differences, but this requires more testing.

Small differences from my test files include...

  • Changes in git hash of spyglass version saved as source script
  • Changes in hdmf/pynwb version saved as part of object names
  • Changes to docstrings for nwb objects

By censoring these values before hashing, I cut down on mismatches (see remove_version in the code below), but the pynwb version also impacts other things, like whether or not the data type is saved, or object reference names (HERD vs. ExternalResources). There were also datasets that appeared to be off by 1e-13 microvolts across the dataset, and places where typos in Spyglass have since been corrected.

File Compare Tool

See files in /stelmo/nwb/analysis/ vs /stelmo/cbroz/temp_rcp/

"""
Usage:
> old = "/stelmo/nwb/analysis/example/example_RAND.nwb"
> new = "/stelmo/cbroz/temp_rcp/example/example_RAND.nwb"
> comp = NwbfileComparator(old,new)
> comp.name_mismatch # see names missing in one or the other
> comp.obj_mismatch # see differing objects
> comp.comp_obj('optional_obj_name') # see diffs in scalar data
"""

import atexit
import re
import warnings
from difflib import SequenceMatcher
from hashlib import md5
from pathlib import Path
from pprint import pprint
from typing import Any, Union

import datajoint as dj
import h5py
import numpy as np
from datajoint.logging import logger as dj_logger

warnings.filterwarnings("ignore", module="hdmf")
warnings.filterwarnings("ignore", module="pynwb")

schema = dj.schema("cbroz_temp")

dj_logger.setLevel("INFO")

DEFAULT_BATCH_SIZE = 4095


class NwbfileComparator:
    def __init__(
        self,
        old: Union[str, Path],
        new: Union[str, Path],
        batch_size: int = DEFAULT_BATCH_SIZE,
        verbose: bool = True,
    ):
        """Compares NWB files by pairwise hashing objects.

        Parameters
        ----------
        old : Union[str, Path]
            Path to the original NWB file.
        new : Union[str, Path]
            Path to the recomputed NWB file.
        batch_size : int, optional
            Limit of data to hash for large datasets, by default 4095.
        verbose : bool, optional
            Display progress bar, by default True.
        """
        if not Path(old).exists():
            raise FileNotFoundError(f"File not found: {old}")
        if not Path(new).exists():
            raise FileNotFoundError(f"File not found: {new}")

        self.old = h5py.File(old, "r")
        self.new = h5py.File(new, "r")
        atexit.register(self.cleanup)

        self.batch_size = batch_size
        self.verbose = verbose
        self.name_mismatch = []
        self.hash_mismatch = []
        self.all_old, self.all_new = [], []
        self.obj_mismatch = dict()

        self.status = "which"

        _ = self.compare_files()

        self._obj_mismatch_iter = iter(self.obj_mismatch.items())
        self.comps = zip(self.all_old, self.all_new)

        if not self.obj_mismatch:  # Only close if no mismatches
            self.cleanup()
        atexit.unregister(self.cleanup)

    def remove_version(self, content):
        version_pattern = (
            r"\d+\.\d+\.\d+"  # Major.Minor.Patch
            + r"(?:-alpha|-beta|a\d+)?"  # Optional alpha or beta, -alpha
            + r"(?:\.dev\d{2})?"  # Optional dev build, .dev01
            + r"(?:\+[a-z0-9]{9})?"  # Optional commit hash, +abcdefghi
            + r"(?:\.d\d{8})?"  # Optional date, dYYYYMMDD
        )
        no_ver = re.sub(version_pattern, "VERSION", content)
        docstring_pattern = r'"doc":"(.*?)"'
        ret = re.sub(docstring_pattern, '"doc":"DOCSTRING"', no_ver)
        return ret

    @property
    def mismatches_diff(self):
        ret = []
        for k, (old, new) in self.obj_mismatch.items():
            ret.append(f"Object: {k}")
            if getattr(old, "shape", None) == ():
                old_str = self.remove_version(str(old[()]))
                new_str = self.remove_version(str(new[()]))
                ret.append(self.diff_strings(old_str, new_str, context=15))
            elif isinstance(old, h5py.Dataset):
                ret.append(str(old[:5]))
                ret.append(str(new[:5]))
            ret.append(" ")
        return "\n".join([r for r in ret if r is not None])

    def cleanup(self):
        self.old.close()
        self.new.close()

    def compare_files(self):
        old_items = self.collect_names(self.old)
        new_items = self.collect_names(self.new)

        all_names = set(old_items.keys()) | set(new_items.keys())

        for name in all_names:
            if name not in old_items:
                self.name_mismatch.append({"name": name, "missing_from": "old"})
                continue
            if name not in new_items:
                self.name_mismatch.append({"name": name, "missing_from": "new"})
                continue
            self.status = "old"
            old_hash = self.compute_hash(name, old_items[name])
            self.status = "new"
            new_hash = self.compute_hash(name, new_items[name])

            if old_hash != new_hash:
                self.hash_mismatch.append({"name": name})
                self.obj_mismatch[name] = (old_items[name], new_items[name])

    def collect_names(self, file):
        """Collects all object names in the file."""

        def collect_items(name, obj):
            name = self.remove_version(name)
            if name in items_to_process:
                raise ValueError(f"Duplicate key: {name}")
            items_to_process.update({name: obj})

        items_to_process = dict()
        file.visititems(collect_items)
        return items_to_process

    def serialize_attr_value(self, value: Any):
        """Serializes an attribute value into bytes for hashing.

        Setting all numpy array types to string avoids false positives.

        Parameters
        ----------
        value : Any
            Attribute value.

        Returns
        -------
        bytes
            Serialized bytes of the attribute value.
        """
        if isinstance(value, np.ndarray):
            return value.astype(str).tobytes()  # must be 'astype(str)'
        elif isinstance(value, (str, int, float)):
            return self.remove_version(str(value)).encode()
        return self.remove_version(repr(value)).encode()

    def hash_dataset(self, dataset: h5py.Dataset):
        hashed = md5(self.hash_shape_dtype(dataset))

        if dataset.shape == ():
            hashed.update(self.serialize_attr_value(dataset[()]))
            return hashed.hexdigest().encode()

        # WARNING: only head of data
        size = min(dataset.shape[0], self.batch_size * 5)
        # size = dataset.shape[0]
        start = 0
        padding = len(str(size))

        while start < size:
            pad_start = f"{round(start,padding-2):0{padding}}"
            print(f"\rData: {dataset.name}: {pad_start}/{size}", end="")
            end = min(start + self.batch_size, size)
            hashed.update(self.serialize_attr_value(dataset[start:end]))
            start = end

        print()
        return hashed.hexdigest().encode()

    def hash_shape_dtype(self, obj: Union[h5py.Dataset, np.ndarray]) -> bytes:
        if not hasattr(obj, "shape") or not hasattr(obj, "dtype"):
            return "".encode()
        return str(obj.shape).encode() + str(obj.dtype).encode()

    def compute_hash(self, name, obj) -> str:
        hashed = md5(name.encode())

        for attr_key in sorted(obj.attrs):
            attr_value = obj.attrs[attr_key]
            hashed = self.uhash(hashed, self.hash_shape_dtype(attr_value))
            hashed = self.uhash(hashed, attr_key.encode())
            hashed = self.uhash(hashed, self.serialize_attr_value(attr_value))

        if isinstance(obj, h5py.Dataset):
            hashed = self.uhash(hashed, self.hash_dataset(obj))
        elif isinstance(obj, h5py.SoftLink):
            hashed = self.uhash(hashed, obj.path.encode())
        elif isinstance(obj, h5py.Group):
            for k, v in obj.items():
                hashed = self.uhash(hashed, self.remove_version(k).encode())
                hashed = self.uhash(hashed, self.serialize_attr_value(v))
        else:
            raise TypeError(
                f"Unknown object type: {type(obj)}\n"
                + "Please report this an issue on GitHub."
            )

        return hashed.hexdigest()

    def uhash(self, hash, value):
        hash.update(value)
        if self.status == "old":
            self.all_old.append(value)
        elif self.status == "new":
            self.all_new.append(value)
        return hash

    def comp_obj(self, obj=None):
        """Show string diffs for given object. 

        If obj=None, iterate over differing objects.
        """
        if obj is not None and obj in self.obj_mismatch:
            key = obj
            old, new = self.obj_mismatch[key]
        else:
            try:
                key, (old, new) = next(self._obj_mismatch_iter)
            except StopIteration:
                return None
        print(f"Object: {key}")
        if getattr(old, "shape", None) == ():
            old = self.remove_version(str(old[()]))
            new = self.remove_version(str(new[()]))
            pprint(self.diff_strings(old, new, context=15))
        return old, new

    def diff_strings(self, a: str, b: str, context=30) -> str:
        """Highlight differences between two strings with surrounding context."""
        a = str(a)
        b = str(b)

        matcher = SequenceMatcher(None, a, b)
        diffs = []
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag != "equal":
                diffs.append(
                    f"...{a[max(0, i1-context):i2+context]}"
                    + f" -> {b[max(0, j1-context):j2+context]}..."
                )
        return "\n".join(diffs)

Questions

From a detail perspective, should two files have the same hash if they differ in these ways?

  1. Saved version
  2. NWB docstrings
  3. Spyglass typos - how do we maintain records of these cases?
  4. Data - what is the appropriate amount of rounding for electrical series? other data?

Each change we make to the input during hashing adds processing time when hashing the recompute product. This is especially true for rounding big datasets.

Bigger picture, what is our process for a mismatch on recompute? Hypothetically, a paper has been submitted using a downstream analysis, and a reviewer asks for summary stats that would require a recomputed file, whose new hash does not match.

  1. Redo: We delete everything downstream of now-unknown provenance, updating all figures with (probably) minor differences in data.
     • Pro: The updated version becomes the new 'official' version
     • Con: Lots of work to update downstream analyses and figures late in the submission process
  2. Asterisk: We accept that intermediate files cannot be kept in perpetuity and use the replicated version as 'close enough'.
     • Pro: Reduced recompute burden
     • Con: Potential concerns about replicability
  3. Resolve: We fully document the conda environment prior to deletion and (either always or only in the case of mismatches) spin up that environment to replicate the file (see the sketch after this list).
     • Pro: Hopefully, full replicability
     • Con:
       • Long wait times for recomputing files
       • Potentially not feasible for under-documented existing files, removing our ability to delete them
       • A lot of overhead infrastructure/dev time to implement env spin-up systematically
Certainly, we can default to one approach and take another on a case-by-case basis.
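
For the 'resolve' route, a minimal sketch of capturing the environment before file deletion (the helper and storage location are hypothetical, not an existing Spyglass feature):

import subprocess
from pathlib import Path


def export_conda_env(env_name: str, out_dir: str) -> Path:
    """Record `conda env export` output so the environment can be rebuilt
    later to replicate a deleted file on a hash mismatch."""
    out_path = Path(out_dir) / f"{env_name}_environment.yml"
    result = subprocess.run(
        ["conda", "env", "export", "-n", env_name, "--no-builds"],
        capture_output=True,
        text=True,
        check=True,
    )
    out_path.write_text(result.stdout)
    return out_path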

Next steps

Next, I'll make an effort to reverse-engineer the required environment from existing files to test our capacity to take the 'resolve' route for existing files.

CBroz1 (Member) commented Dec 4, 2024

Replication

Environment

I've had trouble replicating files without a complete record of the original
conda environment.

The files provide (a) a pynwb version, and (b) a spyglass git hash.
Because pynwb always pins hdmf and h5py, we can use this version to set
those dependencies. Setting spyglass is trickier, as the replication feature
is targeted for a future release. Backporting is possible, but a large tech
burden.

Unfortunately, most of the existing files use pynwb==2.6.0-alpha, for which
there is no record of dependency pins. I've made a best guess based on
one user's environment, but there may be variations, as the pins were changed
in the subsequent release. For replicating, I'll use the versions in the table
below with an up-to-date Spyglass.

File counts by version, and dependency pins
# Existing files by Spyglass and PyNWB versions
spy_ver      0.1  0.4.0  0.4.3  0.5.0  0.5.2  0.5.4  Total
pynwb_ver
2.5.0          0    204      0      0      0      0    204
2.6.0-alpha   42      0    658      1    220      0    921
2.7.0          0      0      0      0     96    171    267
Total         42    204    658      1    316    171   1392

# PyNWB version pins, with *best guess
pynwb version   hdmf       h5py
+------------+ +--------+ +--------+
2.5.0          3.9.0      3.8.0
2.6.0-alpha*   3.11.0     3.10.0
2.6.0          3.12.2     3.10.0
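
For reference, the same pins as a Python mapping (a hypothetical constant that an environment-building helper could consume; the 2.6.0-alpha entry remains a best guess):

# Dependency pins by pynwb version, taken from the table above.
PYNWB_DEP_PINS = {
    "2.5.0": {"hdmf": "3.9.0", "h5py": "3.8.0"},
    "2.6.0-alpha": {"hdmf": "3.11.0", "h5py": "3.10.0"},  # *best guess
    "2.6.0": {"hdmf": "3.12.2", "h5py": "3.10.0"},
}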

Censoring

In an effort to minimize mismatches, I've added the option to censor version
numbers and docstrings from both object names and scalar datasets prior to
hashing. This prevents arbitrary docstring changes from causing a mismatch, but
also treats these fields as strings, rather than nested dictionaries, to speed
up the process. There have been some cases where ExternalResources or Group
objects have had different structures across old and replicated files.

Mismatches

Missing Objects

A few old files are missing general/source_script objects. While the contents
are censored, a missing object still impacts the hash. We could

  1. Ignore: Decide that this field should not be hashed, adding to a list of
    exceptions that becomes hard to maintain.
  2. Retrofit: Add a default value to all existing files where it is missing. This
    would require a careful review of when items were updated, and of the dependency
    versions to use when replicating. A rough sketch of this kind of file surgery is below.
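
A minimal h5py sketch of the 'retrofit' option for a missing general/source_script (the placeholder value and the file_name attribute are assumptions to be confirmed against files that do have the object):

import h5py


def retrofit_source_script(nwb_path, placeholder="unknown"):
    """Hypothetical file surgery: add a placeholder source_script so old and
    replicated files hash the same set of objects."""
    with h5py.File(nwb_path, "r+") as f:
        general = f.require_group("general")
        if "source_script" not in general:
            ds = general.create_dataset("source_script", data=placeholder)
            ds.attrs["file_name"] = ""  # assumed required by the NWB schema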

Hash mismatches

~80% of generated files have a hash mismatch on at least one object.

  • specifications/core/VERSION/nwb.icephys
  • specifications/core/VERSION/nwb.ophys
  • general (missing source_script)
  • specifications/hdmf-experimental/VERSION/resources
  • acquisition/ProcessedElectricalSeries/data

Some of these are actual datasets, while others are groups or scalar datasets
that are initially read as strings. Scalar datasets can be further unpacked
into nested dicts/lists using eval, which presents a security risk and should
be handled carefully. While we can be confident in data we generate, we may
need to adjust the process for imported data.

def unpack_scalar(obj):
    # str() of the scalar gives a bytes repr (b'{"key": null, ...}').
    # JSON literals null/false are quoted so eval doesn't fail on them,
    # [1:] drops the leading 'b', the first eval recovers the inner string,
    # and the second eval parses it into a Python dict.
    return eval(
        eval(
            str(obj[()])
            .replace("null", '"null"')
            .replace("false", '"false"')[1:]
        )
    )
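
If the scalar content is JSON (as it appears to be for the spec strings), a safer alternative could avoid eval entirely. A minimal sketch, assuming the dataset stores UTF-8 JSON bytes:

import json


def unpack_scalar_json(obj):
    """Hypothetical eval-free variant: parse the scalar dataset as JSON."""
    raw = obj[()]
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    return json.loads(raw)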

Icephys

Old nwb.icephys objects are missing value key/value pairs for
'bias_current' in group datasets. As above, we can either attempt to ignore
these or retrofit existing files. Using SC7920240910_B1PER4NSHE.nwb as an
example:

old['groups'][2]['datasets'][0].get('value') == None
new['groups'][2]['datasets'][0].get('value') == 0.0

Ophys

nwb.ophys objects have an additional group with links to an ImagingPlane
object. See j1620210710_QLQKV5ZTOQ.nwb as an example. Again, we can find
a way to ignore or retrofit these objects.

resources

Even controlling for hdmf version, resources objects have mismatching groups
in ~50% of files, especially for those generated with pynwb==2.5.0. Using
Lewis20240222_S3M69NGSAK.nwb as an example, the old file has groups for
entity_keys missing in the new, and the new file has a group for resources
missing in the old. The objects key also differs in structure, with the
old file having 5 datatypes versus 3 in the new.

ProcessedElectricalSeries

Many ProcessedElectricalSeries datasets have mismatching hashes, but will
pass a np.isclose test, with an average difference of less than 1e-16. With
the old data in hand, we could be more confident in the replication, but
adjusting the hash to account for small differences would require significant
increases to the hash time for large datasets.
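
A minimal sketch of quantizing before hashing (the decimals value is arbitrary here; as noted, this adds hash time for large datasets):

import numpy as np


def quantized_bytes(data: np.ndarray, decimals: int = 10) -> bytes:
    """Hypothetical helper: round float data before serializing for hashing,
    so differences below the chosen precision do not change the hash."""
    if np.issubdtype(data.dtype, np.floating):
        data = np.round(data, decimals=decimals)
    return data.astype(str).tobytes()  # mirrors serialize_attr_value above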

Questions

Should the hasher take the time to...

  • Ignore missing general/source_script objects?
  • Unpack scalar datasets into nested dictionaries? (Security risk)
  • Round ProcessedElectricalSeries datasets to a certain precision?

Should we adjust the existing files before hashing to account for the updated
specifications? Are we then tying ourselves to the current spec, and would we need to run a file-surgery overhaul whenever the spec is updated?

  • source_script
  • icephys value
  • ophys ImagingPlane

A growing list of exceptions also increases the risk of a false negative, if future file specs introduce a meaningful difference in some region we've decided to ignore.
