From 55911395662dd31445ea31160bc5e0b3340be2db Mon Sep 17 00:00:00 2001 From: Ryan Harvey Date: Thu, 17 Oct 2024 09:46:11 -0400 Subject: [PATCH] docs, typing, spikes module --- neuro_py/spikes/spike_tools.py | 110 ++++++++++++++++++++++++--------- 1 file changed, 80 insertions(+), 30 deletions(-) diff --git a/neuro_py/spikes/spike_tools.py b/neuro_py/spikes/spike_tools.py index 8c4e7f8..4f80a44 100644 --- a/neuro_py/spikes/spike_tools.py +++ b/neuro_py/spikes/spike_tools.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import List, Union import numpy as np import pandas as pd @@ -6,17 +6,22 @@ def get_spindices(data: np.ndarray) -> pd.DataFrame: """ - Get spike timestamps and spike id for each spike train in a - sorted dataframe of spike trains + Get spike timestamps and spike IDs from each spike train in a sorted DataFrame. + Parameters ---------- data : np.ndarray - spike times for each spike train, in a list of arrays + Spike times for each spike train, where each element is an array of spike times. + Returns ------- - spikes : pd.DataFrame - sorted dataframe of spike times and spike id + pd.DataFrame + A DataFrame containing sorted spike times and the corresponding spike IDs. + Example + ------- + spike_trains = [np.array([0.1, 0.2, 0.4]), np.array([0.15, 0.35])] + spikes = get_spindices(spike_trains) """ spikes_id = [] for spk_i, spk in enumerate(data): @@ -30,20 +35,27 @@ def get_spindices(data: np.ndarray) -> pd.DataFrame: def spindices_to_ndarray( - spikes: pd.DataFrame, spike_id: Union[list, np.ndarray, None] = None -) -> np.ndarray: + spikes: pd.DataFrame, spike_id: Union[List[int], np.ndarray, None] = None +) -> List[np.ndarray]: """ - Convert spike times and spike id to a list of arrays + Convert spike times and spike IDs from a DataFrame into a list of arrays, + where each array contains the spike times for a given spike train. + Parameters ---------- spikes : pd.DataFrame - sorted dataframe of spike times and spike id - spike_id: list or np.ndarray - spike ids search for in the dataframe (important if spikes were restricted) + DataFrame containing 'spike_times' and 'spike_id' columns, sorted by 'spike_times'. + spike_id : list or np.ndarray, optional + List or array of spike IDs to search for in the DataFrame. If None, all spike IDs are used. + Returns ------- - data : np.ndarray - spike times for each spike train, in a list of arrays + List[np.ndarray] + A list of arrays, each containing the spike times for a corresponding spike train. + + Example + ------- + spike_trains = spindices_to_ndarray(spikes_df, spike_id=[0, 1, 2]) """ if spike_id is None: spike_id = np.unique(spikes["spike_id"]) @@ -53,11 +65,34 @@ def spindices_to_ndarray( return data -def BurstIndex_Royer_2012(autocorrs): - # calc burst index from royer 2012 - # burst_idx will range from -1 to 1 - # -1 being non-bursty and 1 being bursty +def BurstIndex_Royer_2012(autocorrs: pd.DataFrame) -> list: + """ + Calculate the burst index from Royer et al. (2012). + The burst index ranges from -1 to 1, where: + -1 indicates non-bursty behavior, and 1 indicates bursty behavior. + Parameters + ---------- + autocorrs : pd.DataFrame + Autocorrelograms of spike trains, with time (in seconds) as index and correlation values as columns. + + Returns + ------- + list + List of burst indices for each autocorrelogram column. + + Notes + ----- + The burst index is calculated as: + burst_idx = (peak - baseline) / max(peak, baseline) + + - Peak is calculated as the maximum value of the autocorrelogram between 2-9 ms. + - Baseline is calculated as the mean value of the autocorrelogram between 40-50 ms. + + Example + ------- + burst_idx = BurstIndex_Royer_2012(autocorr_df) + """ # peak range 2 - 9 ms peak = autocorrs.loc[0.002:0.009].max() # baseline idx 40 - 50 ms @@ -78,19 +113,34 @@ def BurstIndex_Royer_2012(autocorrs): return burst_idx -def select_burst_spikes(spikes, mode="bursts", isiBursts=0.006, isiSpikes=0.020): +def select_burst_spikes( + spikes: np.ndarray, + mode: str = "bursts", + isiBursts: float = 0.006, + isiSpikes: float = 0.020, +) -> np.ndarray: """ - select_burst_spikes - Discriminate bursts vs single spikes. - adpated from: http://fmatoolbox.sourceforge.net/Contents/FMAToolbox/Analyses/SelectSpikes.html - - Input: - spikes: list of spike times - mode: either 'bursts' (default) or 'single' - isiBursts: max inter-spike interval for bursts (default = 0.006) - isiSpikes: min for single spikes (default = 0.020) - Output: - selected: a logical vector indicating for each spike whether it - matches the criterion + Discriminate bursts versus single spikes based on inter-spike intervals. + + Parameters + ---------- + spikes : np.ndarray + Array of spike times. + mode : str, optional + Either 'bursts' (default) or 'single'. + isiBursts : float, optional + Maximum inter-spike interval for bursts (default = 0.006 seconds). + isiSpikes : float, optional + Minimum inter-spike interval for single spikes (default = 0.020 seconds). + + Returns + ------- + np.ndarray + A boolean array indicating for each spike whether it matches the criterion. + + Notes + ----- + Adapted from: http://fmatoolbox.sourceforge.net/Contents/FMAToolbox/Analyses/SelectSpikes.html """ dt = np.diff(spikes)