Skip to content

Commit

Permalink
Homogeneous Poisson Process with refr. period (#261)
Browse files Browse the repository at this point in the history
  • Loading branch information
pbouss authored and dizcza committed Nov 15, 2019
1 parent f378f07 commit 059ccf5
Show file tree
Hide file tree
Showing 2 changed files with 252 additions and 48 deletions.
150 changes: 132 additions & 18 deletions elephant/spike_train_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@
"""

from __future__ import division

import math
import random
import warnings

import numpy as np
from quantities import ms, mV, Hz, Quantity, dimensionless
from neo import SpikeTrain
import random
from quantities import ms, mV, Hz, Quantity, dimensionless

from elephant.spike_train_surrogates import dither_spike_train
import warnings


def spike_extraction(signal, threshold=0.0 * mV, sign='above',
Expand Down Expand Up @@ -93,15 +97,15 @@ def spike_extraction(signal, threshold=0.0 * mV, sign='above',
borders = np.dstack((borders_left, borders_right)).flatten()

waveforms = np.array(
np.split(np.array(signal), borders.astype(int))[1::2]) * signal.units
np.split(np.array(signal), borders.astype(int))[1::2]) * signal.units

# len(np.shape(waveforms)) == 1 if waveforms do not have the same width.
# this can occur when extr_interval indexes beyond the signal.
# Workaround: delete spikes shorter than the maximum length with
if len(np.shape(waveforms)) == 1:
max_len = (np.array([len(x) for x in waveforms])).max()
to_delete = np.array([idx for idx, x in enumerate(waveforms)
if len(x) < max_len])
if len(x) < max_len])
waveforms = np.delete(waveforms, to_delete, axis=0)
waveforms = np.array([x for x in waveforms])
warnings.warn("Waveforms " +
Expand Down Expand Up @@ -164,7 +168,8 @@ def threshold_detection(signal, threshold=0.0 * mV, sign='above'):
if events_base is None:
# This occurs in some Python 3 builds due to some
# bug in quantities.
events_base = np.array([event.magnitude for event in events]) # Workaround
events_base = np.array(
[event.magnitude for event in events]) # Workaround

result_st = SpikeTrain(events_base, units=signal.times.units,
t_start=signal.t_start, t_stop=signal.t_stop)
Expand Down Expand Up @@ -217,7 +222,7 @@ def peak_detection(signal, threshold=0.0 * mV, sign='above', format=None):
border_end = border_start + 1
borders = np.concatenate((border_start, border_end))
borders = np.append(0, borders)
borders = np.append(borders, len(cutout)-1)
borders = np.append(borders, len(cutout) - 1)
borders = np.sort(borders)
true_borders = cutout[borders]
right_borders = true_borders[1::2] + 1
Expand All @@ -241,7 +246,8 @@ def peak_detection(signal, threshold=0.0 * mV, sign='above', format=None):
if events_base is None:
# This occurs in some Python 3 builds due to some
# bug in quantities.
events_base = np.array([event.magnitude for event in events]) # Workaround
events_base = np.array(
[event.magnitude for event in events]) # Workaround
if format is None:
result_st = SpikeTrain(events_base, units=signal.times.units,
t_start=signal.t_start, t_stop=signal.t_stop)
Expand All @@ -260,6 +266,7 @@ def _homogeneous_process(interval_generator, args, mean_rate, t_start, t_stop,
generated by the function `interval_generator` with the given rate,
starting at time `t_start` and stopping `time t_stop`.
"""

def rescale(x):
return (x / mean_rate.units).rescale(t_stop.units)

Expand Down Expand Up @@ -346,8 +353,8 @@ def inhomogeneous_poisson_process(rate, as_array=False):
Parameters
----------
rate : neo.AnalogSignal
A `neo.AnalogSignal` representing the rate profile evolving over time.
Its values have all to be `>=0`. The output spiketrain will have
A `neo.AnalogSignal` representing the rate profile evolving over time.
Its values have all to be `>=0`. The output spiketrain will have
`t_start = rate.t_start` and `t_stop = rate.t_stop`
as_array : bool
If True, a NumPy array of sorted spikes is returned,
Expand All @@ -362,14 +369,14 @@ def inhomogeneous_poisson_process(rate, as_array=False):
'rate must be a positive non empty signal, representing the'
'rate at time t')
else:
#Generate n hidden Poisson SpikeTrains with rate equal to the peak rate
# Generate n hidden Poisson SpikeTrains with rate equal to the peak rate
max_rate = np.max(rate)
homogeneous_poiss = homogeneous_poisson_process(
rate=max_rate, t_stop=rate.t_stop, t_start=rate.t_start)
# Compute the rate profile at each spike time by interpolation
rate_interpolated = _analog_signal_linear_interp(
signal=rate, times=homogeneous_poiss.magnitude *
homogeneous_poiss.units)
homogeneous_poiss.units)
# Accept each spike at time t with probability rate(t)/max_rate
u = np.random.uniform(size=len(homogeneous_poiss)) * max_rate
spikes = homogeneous_poiss[u < rate_interpolated.flatten()]
Expand All @@ -383,8 +390,8 @@ def _analog_signal_linear_interp(signal, times):
Compute the linear interpolation of a signal at desired times.
Given the `signal` (neo.AnalogSignal) taking value `s0` and `s1` at two
consecutive time points `t0` and `t1` `(t0 < t1)`, for every time `t` in
`times`, such that `t0<t<=t1` is returned the value of the linear
consecutive time points `t0` and `t1` `(t0 < t1)`, for every time `t` in
`times`, such that `t0<t<=t1` is returned the value of the linear
interpolation, given by:
`s = ((s1 - s0) / (t1 - t0)) * t + s0`.
Expand All @@ -403,10 +410,10 @@ def _analog_signal_linear_interp(signal, times):
Notes
-----
If `signal` has sampling period `dt=signal.sampling_period`, its values
are defined at `t=signal.times`, such that `t[i] = signal.t_start + i * dt`
The last of such times is lower than
signal.t_stop`:t[-1] = signal.t_stop - dt`.
If `signal` has sampling period `dt=signal.sampling_period`, its values
are defined at `t=signal.times`, such that `t[i] = signal.t_start + i * dt`
The last of such times is lower than
signal.t_stop`:t[-1] = signal.t_stop - dt`.
For the interpolation at times t such that `t[-1] <= t <= signal.t_stop`,
the value of `signal` at `signal.t_stop` is taken to be that
at time `t[-1]`.
Expand All @@ -432,6 +439,7 @@ def _analog_signal_linear_interp(signal, times):
out = (y1 + m * (times - times_extended[time_ids])) * signal.units
return out.rescale(signal.units)


def homogeneous_gamma_process(a, b, t_start=0.0 * ms, t_stop=1000.0 * ms,
as_array=False):
"""
Expand Down Expand Up @@ -1060,5 +1068,111 @@ def compound_poisson_process(rate, A, t_stop, shift=None, t_start=0 * ms):
for cp in cpp]
return cpp


# Alias for the compound poisson process
cpp = compound_poisson_process


def homogeneous_poisson_process_with_refr_period(rate,
refr_period=2. * ms,
t_start=0.0 * ms,
t_stop=1000.0 * ms,
as_array=False):
"""
Returns a spike train whose spikes are a realization of a Poisson process
with the given rate and refractory period starting at time `t_start` and
stopping time `t_stop`.
All numerical values should be given as Quantities, e.g. 100*Hz.
Parameters
----------
rate : Quantity
Quantity scalar with dimension 1/time
The rate of the discharge.
refr_period : Quantity
Quantity scalar with dimension time
The time period the after one spike no other spike is
emitted.
Default: 2 ms
t_start : Quantity
Quantity scalar with dimension time
The beginning of the spike train.
Default: 0 ms
t_stop : Quantity
Quantity scalar with dimension time
The end of the spike train.
Default: 1000 ms
as_array : bool
If True, a NumPy array of sorted spikes is returned,
rather than a SpikeTrain object.
Default: False
Returns
-------
st : SpikeTrain or np.ndarray
Array of spike times.
If `as_array` is False, the output is wrapped in `neo.SpikeTrain`.
Raises
------
ValueError
If one of `rate`, `refr_period`, `t_start` and `t_stop` is not
of type `pq.Quantity`.
If `t_stop <= t_start`.
If the period between two successive spikes (`1 / rate`) is
smaller than the `refr_period`.
Examples
--------
>>> from quantities import Hz, ms
>>> spikes = homogeneous_poisson_process_with_refr_period(
50*Hz, 3*ms, 0*ms, 1000*ms)
>>> spikes = homogeneous_poisson_process_with_refr_period(
20*Hz, 5*ms, 5000*ms, 10000*ms, as_array=True)
"""
if not isinstance(t_start, Quantity) or not isinstance(t_stop, Quantity):
raise ValueError("t_start and t_stop must be of type pq.Quantity")
if not isinstance(refr_period, Quantity):
raise ValueError("refr_period must be of type pq.Quantity")
if not isinstance(rate, Quantity):
raise ValueError("rate must be of type pq.Quantity")

if t_stop.units != t_start.units:
t_stop = t_stop.rescale(t_start.units)

if t_stop <= t_start:
raise ValueError("t_stop must be larger than t_start")

rate_mag = rate.rescale(1 / t_start.units).magnitude
refr_period_mag = refr_period.rescale(t_start.units).magnitude

if 1. / rate_mag <= refr_period_mag:
raise ValueError("Period between two successive spikes must be larger"
"than the refractory period. Decrease either the"
"firing rate or the refractory period.")

duration = (t_stop - t_start).magnitude
mean_spike_count = rate_mag * duration
spike_count = np.random.poisson(lam=mean_spike_count)

# Check that the number of spikes drawn from the Poisson distribution,
# can fit in the duration regarding the refractory period.
spike_count = min(spike_count, math.ceil(duration / refr_period))

# Due to the refractory period the effective space in where the spikes
# can be placed is shortened by (spike_count - 1) times the
# refr. period.
eff_duration = duration - (spike_count - 1) * refr_period_mag

# In this effective time interval each spike is placed uniformly,
# and after sorting to each ISI one refractory period is added
st = np.random.uniform(high=eff_duration, size=spike_count)
st.sort()
st += refr_period_mag * np.arange(spike_count)
st += t_start.magnitude
if not as_array:
return SpikeTrain(st, t_start=t_start, t_stop=t_stop,
units=t_start.units)
return st
Loading

0 comments on commit 059ccf5

Please sign in to comment.