Skip to content

Commit

Permalink
Merge pull request #15 from SpikeInterface/np-ultra
Browse files Browse the repository at this point in the history
Split np-ultra into multiple datasets
  • Loading branch information
alejoe91 authored Sep 29, 2024
2 parents 40f0d7b + c8abce7 commit 0023db2
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 38 deletions.
1 change: 1 addition & 0 deletions python/consolidate_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def consolidate_datasets(dry_run: bool = False, verbose: bool = False):
zarr_datasets = list_zarr_directories(bucket_name=bucket, boto_client=boto_client)
datasets_to_avoid = ["test_templates.zarr"]
zarr_datasets = [d for d in zarr_datasets if d not in datasets_to_avoid]
zarr_datasets = sorted(zarr_datasets)

if not zarr_datasets:
raise FileNotFoundError(f"No Zarr datasets found in bucket: {bucket}")
Expand Down
6 changes: 5 additions & 1 deletion python/delete_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,16 @@ def delete_templates_too_few_spikes(min_spikes=50, dry_run=False, verbose=True):
n_original_units = len(all_unit_indices)
unit_indices_to_keep = np.delete(all_unit_indices, template_indices_to_remove)
n_units_to_keep = len(unit_indices_to_keep)

spikes_per_unit = zarr_root["spikes_per_unit"]
if verbose:
print(f"\tMax spikes to remove: {spikes_per_unit[template_indices_to_remove]}")
print(f"\tRemoving {n_original_units - n_units_to_keep} templates from {n_original_units}")
for dset in datasets_to_filter:
dataset_original = zarr_root[dset]
if len(dataset_original) == n_units_to_keep:
print(f"\t\tDataset: {dset} - shape: {dataset_original.shape} - already updated")
if verbose:
print(f"\t\tDataset: {dset} - shape: {dataset_original.shape} - already updated")
continue
dataset_filtered = dataset_original[unit_indices_to_keep]
if not dry_run:
Expand Down
1 change: 0 additions & 1 deletion python/upload_ibl_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
expected_shape = (number_of_units, number_of_temporal_samples, number_of_channels)
assert templates_extension_data.templates_array.shape == expected_shape

# TODO: skip templates with 0 amplitude!
# TODO: check for weird shapes
templates_extension = analyzer.get_extension("templates")
templates_object = templates_extension.get_data(outputs="Templates")
Expand Down
92 changes: 56 additions & 36 deletions python/upload_npultra_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import pandas as pd
import os
import numcodecs
from tqdm.auto import tqdm

import probeinterface as pi
import spikeinterface as si

Expand All @@ -40,12 +42,13 @@ def smooth_edges(templates, pad_samples, smooth_percent=0.5, smooth_strength=1):

# parameters
min_spikes_per_unit = 50
num_templates_per_dataset = 100
target_nbefore = 90
target_nafter = 150
upload_data = False

npultra_templates_path = Path("/home/alessio/Documents/Data/Templates/NPUltraWaveforms/")
dataset_name = "steinmetz_ye_np_ultra_2022_figshare19493588v2.zarr"
dataset_stem = "steinmetz_ye_np_ultra_2022_figshare19493588v2"

# AWS credentials
aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
Expand Down Expand Up @@ -117,42 +120,59 @@ def smooth_edges(templates, pad_samples, smooth_percent=0.5, smooth_strength=1):
templates_ultra = si.Templates(
templates_array=templates_smoothed,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
nbefore=target_nbefore,
unit_ids=unit_ids,
probe=probe,
is_scaled=True,
)

best_channel_index = si.get_template_extremum_channel(templates_ultra, mode="peak_to_peak", outputs="index")
best_channel_index = list(best_channel_index.values())

if upload_data:
# Create a S3 file system object with explicit credentials
s3_kwargs = dict(anon=False, key=aws_access_key_id, secret=aws_secret_access_key, client_kwargs=client_kwargs)
s3 = s3fs.S3FileSystem(**s3_kwargs)

# Specify the S3 bucket and path
s3_path = f"{bucket_name}/{dataset_name}"
store = s3fs.S3Map(root=s3_path, s3=s3)
else:
folder_path = Path.cwd() / "build" / f"{dataset_name}"
folder_path.mkdir(exist_ok=True, parents=True)
store = zarr.DirectoryStore(str(folder_path))

# Save results to Zarr
zarr_group = zarr.group(store=store, overwrite=True)
zarr_group.create_dataset(name="brain_area", data=brain_area, object_codec=numcodecs.VLenUTF8())
zarr_group.create_dataset(name="spikes_per_unit", data=spikes_per_unit, chunks=None, dtype="uint32")
zarr_group.create_dataset(
name="best_channel_index",
data=best_channel_index,
chunks=None,
dtype="uint32",
)
peak_to_peak = np.ptp(templates_array, axis=1)
zarr_group.create_dataset(name="peak_to_peak", data=peak_to_peak)

# Now you can create a Zarr array using this store
templates_ultra.add_templates_to_zarr_group(zarr_group=zarr_group)
zarr_group_s3 = zarr_group
zarr.consolidate_metadata(zarr_group_s3.store)
print(f"Full templates: {templates_ultra}")

split_indices = np.arange(0, len(unit_ids), num_templates_per_dataset)

for i in tqdm(np.arange(len(split_indices)), desc="Uploading dataset in chunks"):
index = split_indices[i]
if i < len(split_indices) - 1:
s = slice(index, split_indices[i + 1])
else:
s = slice(index, len(unit_ids))
unit_ids_split = unit_ids[s]
brain_area_split = brain_area_acronym[s]
spikes_per_unit_split = spikes_per_unit[s]

templates_split = templates_ultra.select_units(unit_ids_split)
print(f"Creating dataset {i} with {len(unit_ids_split)} units")
dataset_name = f"{dataset_stem}_{i}.zarr"

best_channel_index = si.get_template_extremum_channel(templates_split, mode="peak_to_peak", outputs="index")
best_channel_index = list(best_channel_index.values())

if upload_data:
# Create a S3 file system object with explicit credentials
s3_kwargs = dict(anon=False, key=aws_access_key_id, secret=aws_secret_access_key, client_kwargs=client_kwargs)
s3 = s3fs.S3FileSystem(**s3_kwargs)

# Specify the S3 bucket and path
s3_path = f"{bucket_name}/{dataset_name}"
store = s3fs.S3Map(root=s3_path, s3=s3)
else:
folder_path = Path.cwd() / "build" / f"{dataset_name}"
folder_path.mkdir(exist_ok=True, parents=True)
store = zarr.DirectoryStore(str(folder_path))

# Save results to Zarr
zarr_group = zarr.group(store=store, overwrite=True)
zarr_group.create_dataset(name="brain_area", data=brain_area_split, object_codec=numcodecs.VLenUTF8())
zarr_group.create_dataset(name="spikes_per_unit", data=spikes_per_unit_split, chunks=None, dtype="uint32")
zarr_group.create_dataset(
name="best_channel_index",
data=best_channel_index,
chunks=None,
dtype="uint32",
)
peak_to_peak = np.ptp(templates_split.templates_array, axis=1)
zarr_group.create_dataset(name="peak_to_peak", data=peak_to_peak)

# Now you can create a Zarr array using this store
templates_split.add_templates_to_zarr_group(zarr_group=zarr_group)
zarr_group_s3 = zarr_group
zarr.consolidate_metadata(zarr_group_s3.store)

0 comments on commit 0023db2

Please sign in to comment.