Merge pull request #5 from SpikeInterface/update-consolidat
Update consolidate script
h-mayorquin authored May 8, 2024
2 parents 2066383 + 74de0f1 commit 646c8df
Showing 5 changed files with 64 additions and 31 deletions.
8 changes: 4 additions & 4 deletions package-lock.json

Some generated files (package-lock.json) are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
@@ -17,7 +17,7 @@
"@testing-library/react": "^13.4.0",
"@testing-library/user-event": "^13.5.0",
"plotly.js-dist": "^2.29.0",
"react": "^18.2.0",
"react": "^18.3.1",
"react-dom": "^18.2.0",
"react-scripts": "^5.0.1",
"react-syntax-highlighter": "^15.5.0",
69 changes: 50 additions & 19 deletions python/consolidate_datasets.py
@@ -2,11 +2,15 @@
import pandas as pd
import zarr
import numpy as np
from argparse import ArgumentParser

from spikeinterface.core import Templates

HYBRID_BUCKET = "spikeinterface-template-database"
SKIP_TEST = True
parser = ArgumentParser(description="Consolidate datasets from spikeinterface template database")

parser.add_argument("--dry-run", action="store_true", help="Dry run (no upload)")
parser.add_argument("--no-skip-test", action="store_true", help="Skip test datasets")
parser.add_argument("--bucket", type=str, help="S3 bucket name", default="spikeinterface-template-database")


def list_bucket_objects(
@@ -43,34 +47,39 @@ def list_bucket_objects(
return keys
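
For reference, a minimal sketch of what the list_bucket_objects helper might look like with the skip_substrings parameter used in the call below. This is an assumption inferred from the call site, not the repository's actual implementation:

import boto3

def list_bucket_objects(bucket, boto_client=None, include_substrings=None, skip_substrings=None):
    # Hypothetical sketch; parameter names and defaults are inferred from the call below.
    boto_client = boto_client or boto3.client("s3")
    if isinstance(include_substrings, str):
        include_substrings = [include_substrings]
    if isinstance(skip_substrings, str):
        skip_substrings = [skip_substrings]
    keys = []
    paginator = boto_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            # keep only keys that contain every include substring
            if include_substrings and not all(s in key for s in include_substrings):
                continue
            # drop keys that contain any skip substring
            if skip_substrings and any(s in key for s in skip_substrings):
                continue
            keys.append(key)
    return keys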


def consolidate_datasets():
def consolidate_datasets(
dry_run: bool = False, skip_test_folder: bool = True, bucket="spikeinterface-template-database"
):
### Find datasets and create dataframe with consolidated data
bc = boto3.client("s3")

# Each dataset is stored in a zarr folder, so we look for the .zattrs files
exclude_substrings = ["test_templates"] if SKIP_TEST else None
keys = list_bucket_objects(
HYBRID_BUCKET, boto_client=bc, include_substrings=".zattrs", exclude_substrings=exclude_substrings
)
skip_substrings = ["test_templates"] if skip_test_folder else None
keys = list_bucket_objects(bucket, boto_client=bc, include_substrings=".zattrs", skip_substrings=skip_substrings)
datasets = [k.split("/")[0] for k in keys]
print(f"Found {len(datasets)} datasets to consolidate\n")

templates_df = pd.DataFrame(
columns=["dataset", "template_index", "best_channel_id", "brain_area", "depth", "amplitude"]
)
templates_df = None

# Loop over datasets and extract relevant information
for dataset in datasets:
print(f"Processing dataset {dataset}")
zarr_path = f"s3://{HYBRID_BUCKET}/{dataset}"
zarr_path = f"s3://{bucket}/{dataset}"
zarr_group = zarr.open_consolidated(zarr_path, storage_options=dict(anon=True))

templates = Templates.from_zarr_group(zarr_group)

num_units = templates.num_units
dataset_list = [dataset] * num_units
dataset_path = [zarr_path] * num_units
template_idxs = np.arange(num_units)
best_channel_idxs = zarr_group.get("best_channels", None)
brain_areas = zarr_group.get("brain_area", None)
peak_to_peaks = zarr_group.get("peak_to_peak", None)
spikes_per_units = zarr_group.get("spikes_per_unit", None)

# TODO: get probe name from probe metadata

channel_depths = templates.get_channel_locations()[:, 1]

depths = np.zeros(num_units)
@@ -80,32 +89,54 @@ def consolidate_datasets():
best_channel_idxs = best_channel_idxs[:]
for i, best_channel_idx in enumerate(best_channel_idxs):
depths[i] = channel_depths[best_channel_idx]
amps[i] = np.ptp(templates.templates_array[i, :, best_channel_idx])
if peak_to_peaks is None:
amps[i] = np.ptp(templates.templates_array[i, :, best_channel_idx])
else:
amps[i] = peak_to_peaks[i, best_channel_idx]
else:
depths = np.nan
amps = np.nan
best_channels = ["unknwown"] * num_units
best_channel_idxs = [-1] * num_units
spikes_per_units = [-1] * num_units
if brain_areas is not None:
brain_areas = brain_areas[:]
else:
brain_areas = ["unknwown"] * num_units

new_entry = pd.DataFrame(
data={
"dataset": dataset_list,
"dataset_path": dataset_path,
"probe": ["Neuropixels1.0"] * num_units,
"template_index": template_idxs,
"best_channel_id": best_channels,
"best_channel_id": best_channel_idxs,
"brain_area": brain_areas,
"depth": depths,
"amplitude": amps,
"depth_along_probe": depths,
"amplitude_uv": amps,
"spikes_per_unit": spikes_per_units,
}
)
templates_df = pd.concat([templates_df, new_entry])
if templates_df is None:
templates_df = new_entry
else:
templates_df = pd.concat([templates_df, new_entry])
print(f"Added {num_units} units from dataset {dataset}")

templates_df.reset_index(inplace=True)
templates_df.to_csv("templates.csv", index=False)

# Upload to S3
bc.upload_file("templates.csv", HYBRID_BUCKET, "templates.csv")
if not dry_run:
bc.upload_file("templates.csv", bucket, "templates.csv")
else:
print("Dry run, not uploading")
print(templates_df)

return templates_df


if __name__ == "__main__":
consolidate_datasets()
params = parser.parse_args()
DRY_RUN = params.dry_run
SKIP_TEST = not params.no_skip_test
templates_df = consolidate_datasets(dry_run=DRY_RUN, skip_test_folder=SKIP_TEST)
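
With these changes the script can be exercised from the command line and its output inspected locally. A hedged usage sketch, where the invocations and the column selection are illustrative and based only on the flags and dataframe columns defined above:

# Possible invocations of the updated script:
#   python python/consolidate_datasets.py --dry-run        # build templates.csv but skip the S3 upload
#   python python/consolidate_datasets.py --no-skip-test   # also include the test_templates folder
#   python python/consolidate_datasets.py --bucket my-own-template-bucket

import pandas as pd

# Inspect the consolidated index written by the script
templates_df = pd.read_csv("templates.csv")
print(templates_df[["dataset", "template_index", "brain_area", "depth_along_probe", "amplitude_uv"]].head())
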
8 changes: 4 additions & 4 deletions src/components/App.js
@@ -8,7 +8,7 @@ import "../styles/App.css";
//const url = "http://localhost:8000/zarr_store.zarr";
const url = "https://spikeinterface-template-database.s3.us-east-2.amazonaws.com/test_templates";

const url = process.env.TEST_URL || "https://s3.amazonaws.com/my-bucket/templates";
//const url = process.env.TEST_URL || "https://s3.amazonaws.com/my-bucket/templates";


function App() {
@@ -20,7 +20,7 @@ function App() {
const batchSize = 10;
const [dataDictionary, setDataDictionary] = useState({});

const loadTempalteIndices = () => {
const loadTemplateIndices = () => {
const nextIndex = templateIndices.length === 0 ? 0 : Math.max(...templateIndices) + 1;
const newIndices = Array.from({ length: batchSize }, (_, i) => i + nextIndex);

@@ -82,7 +82,7 @@ function App() {
};

useEffect(() => {
loadTempalteIndices();
loadTemplateIndices();
loadSessionData();
}, []);

@@ -104,7 +104,7 @@
))}
</div>
{hasMore && (
<button onClick={loadTempalteIndices} className="load-more-button">
<button onClick={loadTemplateIndices} className="load-more-button">
Load More Templates
</button>
)}
8 changes: 5 additions & 3 deletions src/components/CodeSnippet.js
@@ -6,10 +6,12 @@ import { vs } from 'react-syntax-highlighter/dist/esm/styles/prism';
function CodeSnippet({ selectedTemplates }) {
const generatePythonCode = (selectedTemplates) => {
const selectedUnitIndicesString = JSON.stringify([...selectedTemplates], null, 2);
return `from spikeinterface.hybrid import generate_recording_from_template_database
return `from spikeinterface.generation import get_templates_from_database, generate_hybrid_recording
selected_unit_indices = ${selectedUnitIndicesString}
durations = [1.0] # Specify the duration for each template
recording = generate_recording_from_template_database(selected_unit_indices, durations=durations)`;
templates = get_templates_from_database(selected_unit_indices)
# recording is an existing spikeinterface.BaseRecording
recording_hybrid = generate_hybrid_recording(recording, templates=templates)`;
};

const pythonCode = generatePythonCode(selectedTemplates);
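
For context, the snippet generated above could be used along the following lines. The signatures of get_templates_from_database and generate_hybrid_recording are taken from the generated code and not verified here, and the recording source is only an illustration:

from spikeinterface.extractors import read_spikeglx
from spikeinterface.generation import get_templates_from_database, generate_hybrid_recording

# Indices copied from the selection made in the web UI (illustrative values)
selected_unit_indices = [0, 5, 12]

# Fetch the selected templates from the template database
templates = get_templates_from_database(selected_unit_indices)

# Any existing BaseRecording can serve as the background recording;
# a SpikeGLX folder on disk is assumed here purely for illustration.
recording = read_spikeglx("/path/to/spikeglx_folder")

# Inject the templates into the recording to obtain a hybrid ground-truth recording
recording_hybrid = generate_hybrid_recording(recording, templates=templates)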
