Boundary correction rate for firing rate estimator with Gaussian KDE (#414)

Co-authored-by: stellalessandra <[email protected]>
pbouss and stellalessandra authored Mar 25, 2022
1 parent 392933a commit 4979f25
Showing 4 changed files with 107 additions and 24 deletions.
1 change: 0 additions & 1 deletion .travis.yml
@@ -74,4 +74,3 @@ install:

 script:
   pytest --cov=elephant --import-mode=importlib

43 changes: 39 additions & 4 deletions elephant/statistics.py
@@ -74,6 +74,7 @@
 import numpy as np
 import quantities as pq
 import scipy.stats
+from scipy.special import erf

 import elephant.conversion as conv
 import elephant.kernels as kernels
@@ -601,7 +602,7 @@ def lvr(time_intervals, R=5*pq.ms, with_nan=False):
 @deprecated_alias(spiketrain='spiketrains')
 def instantaneous_rate(spiketrains, sampling_period, kernel='auto',
                        cutoff=5.0, t_start=None, t_stop=None, trim=False,
-                       center_kernel=True):
+                       center_kernel=True, border_correction=False):
     """
     Estimates instantaneous firing rate by kernel convolution.
@@ -625,10 +626,12 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',
         triangular, epanechnikovlike, gaussian, laplacian, exponential, and
         alpha function.
         If 'auto', the optimized kernel width for the rate estimation is
-        calculated according to :cite:`statistics-Shimazaki2010_171` and with
-        this width a gaussian kernel is constructed. Automatized calculation
-        of the kernel width is not available for other than gaussian kernel
+        calculated according to :cite:`statistics-Shimazaki2010_171` and a
+        Gaussian kernel is constructed with this width. Automated calculation
+        of the kernel width is not available for other than Gaussian kernel
         shapes.
+        Note: The kernel width is not adaptive, i.e., it is calculated as a
+        global optimum across the data.
         Default: 'auto'
     cutoff : float, optional
         This factor determines the cutoff of the probability distribution of
@@ -665,6 +668,14 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',
         spike. If False, no adjustment is performed such that the spike sits at
         the origin of the kernel.
         Default: True
+    border_correction : bool, optional
+        Apply a border correction to prevent underestimating the firing rates
+        at the borders of the spike trains, i.e., close to t_start and t_stop.
+        The correction is done by estimating the mass of the kernel outside
+        these spike train borders under the assumption that the rate does not
+        change strongly near the borders.
+        Only possible in the case of a Gaussian kernel.
+        Default: False

     Returns
     -------
@@ -766,6 +777,12 @@ def optimal_kernel(st):
                 "instantaneous rate from input data.")
         return kernels.GaussianKernel(width_sigma * st.units)

+    if border_correction and not \
+            (kernel == 'auto' or isinstance(kernel, kernels.GaussianKernel)):
+        raise ValueError(
+            'The border correction is only implemented'
+            ' for Gaussian kernels.')
+
     if isinstance(spiketrains, neo.SpikeTrain):
         if kernel == 'auto':
             kernel = optimal_kernel(spiketrains)
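Given this guard, requesting the correction with any non-Gaussian kernel fails fast. A minimal sketch of the resulting behavior (the spike train st and the kernel choice are illustrative, not part of the commit):

    from elephant.kernels import RectangularKernel
    from elephant.statistics import instantaneous_rate

    # raises ValueError: the border correction is only implemented
    # for Gaussian kernels
    instantaneous_rate(st, sampling_period=1 * pq.ms,
                       kernel=RectangularKernel(sigma=10 * pq.ms),
                       border_correction=True)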
@@ -899,6 +916,24 @@ def optimal_kernel(st):
                            units=pq.Hz, t_start=t_start, t_stop=t_stop,
                            kernel=kernel_annotation)

+    if border_correction:
+        sigma = kernel.sigma.simplified.magnitude
+        times = rate.times.simplified.magnitude
+        correction_factor = 2 / (
+            erf((t_stop.simplified.magnitude - times) / (
+                np.sqrt(2.) * sigma))
+            - erf((t_start.simplified.magnitude - times) / (
+                np.sqrt(2.) * sigma)))
+
+        rate *= correction_factor[:, None]
+
+        duration = t_stop.simplified.magnitude - t_start.simplified.magnitude
+        # ensure the integral over the firing rate yields the exact number of spikes
+        for i, spiketrain in enumerate(spiketrains):
+            if len(spiketrain) > 0:
+                rate[:, i] *= len(spiketrain) / \
+                    (np.mean(rate[:, i]).magnitude * duration)
+
     return rate
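For reference, the correction_factor above is the reciprocal of the fraction of Gaussian kernel mass that falls inside the recording window. For a kernel of width sigma evaluated at time t:

    correction(t) = 2 / ( erf((t_stop - t) / (sqrt(2) * sigma))
                        - erf((t_start - t) / (sqrt(2) * sigma)) )

Far from both borders the denominator approaches 2 and the factor tends to 1; directly at a border half of the kernel mass lies outside the window and the factor approaches 2. The per-train rescaling that follows then makes the integral of the corrected rate match the observed spike count exactly.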


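A minimal usage sketch of the new keyword, based on the signature introduced in this commit (the spike times and kernel width are illustrative):

    import neo
    import quantities as pq

    from elephant.kernels import GaussianKernel
    from elephant.statistics import instantaneous_rate

    # three spikes on a 1 s recording; without the correction the estimate
    # dips near t_start and t_stop because kernel mass leaks outside
    st = neo.SpikeTrain([0.1, 0.5, 0.9] * pq.s, t_start=0 * pq.s,
                        t_stop=1 * pq.s)
    rate = instantaneous_rate(st, sampling_period=1 * pq.ms,
                              kernel=GaussianKernel(sigma=100 * pq.ms),
                              border_correction=True)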
5 changes: 1 addition & 4 deletions elephant/test/test_spade.py
@@ -7,13 +7,12 @@
 from __future__ import division

 import unittest
-import random

 import neo
 from neo.core.spiketrainlist import SpikeTrainList
 import numpy as np
 import quantities as pq
-from numpy.testing.utils import assert_array_equal
+from numpy.testing import assert_array_equal

 import elephant.conversion as conv
 import elephant.spade as spade
@@ -289,8 +288,6 @@ def test_parameters(self):
         elements_msip_max_spikes = []
         for out in output_msip_max_spikes:
             elements_msip_max_spikes.append(out['neurons'])
-        elements_msip_max_spikes = sorted(
-            elements_msip_max_spikes, key=len)
         lags_msip_max_spikes = []
         for out in output_msip_max_spikes:
             lags_msip_max_spikes.append(list(out['lags'].magnitude))
82 changes: 67 additions & 15 deletions elephant/test/test_statistics.py
@@ -17,10 +17,9 @@
 import scipy.integrate as spint
 from numpy.testing import assert_array_almost_equal, assert_array_equal, \
     assert_array_less
-
 import elephant.kernels as kernels
 from elephant import statistics
-from elephant.spike_train_generation import homogeneous_poisson_process
+from elephant.spike_train_generation import StationaryPoissonProcess


 class isi_TestCase(unittest.TestCase):
@@ -139,7 +138,8 @@ def test_mean_firing_rate_with_spiketrain(self):

     def test_mean_firing_rate_typical_use_case(self):
         np.random.seed(92)
-        st = homogeneous_poisson_process(rate=100 * pq.Hz, t_stop=100 * pq.s)
+        st = StationaryPoissonProcess(
+            rate=100 * pq.Hz, t_stop=100 * pq.s).generate_spiketrain()
         rate1 = statistics.mean_firing_rate(st)
         rate2 = statistics.mean_firing_rate(st, t_start=st.t_start,
                                             t_stop=st.t_stop)
@@ -580,6 +580,9 @@ def test_rate_estimation_consistency(self):
         kernels_available.append('auto')
         kernel_resolution = 0.01 * pq.s
         for kernel in kernels_available:
+            border_correction = False
+            if isinstance(kernel, kernels.GaussianKernel):
+                border_correction = True
             for center_kernel in (False, True):
                 rate_estimate = statistics.instantaneous_rate(
                     self.spike_train,
@@ -588,7 +591,9 @@
                     t_start=self.st_tr[0] * pq.s,
                     t_stop=self.st_tr[1] * pq.s,
                     trim=False,
-                    center_kernel=center_kernel)
+                    center_kernel=center_kernel,
+                    border_correction=border_correction
+                )
                 num_spikes = len(self.spike_train)
                 auc = spint.cumtrapz(
                     y=rate_estimate.magnitude[:, 0],
@@ -616,9 +621,9 @@ def test_not_center_kernel(self):
     def test_regression_288(self):
         np.random.seed(9)
         sampling_period = 200 * pq.ms
-        spiketrain = homogeneous_poisson_process(10 * pq.Hz,
-                                                 t_start=0 * pq.s,
-                                                 t_stop=10 * pq.s)
+        spiketrain = StationaryPoissonProcess(
+            10 * pq.Hz, t_start=0 * pq.s, t_stop=10 * pq.s
+        ).generate_spiketrain()
         kernel = kernels.AlphaKernel(sigma=5 * pq.ms, invert=True)
         # check that instantaneous_rate "works" for kernels with small sigma
         # without triggering an incomprehensible error
@@ -636,9 +641,9 @@ def test_small_kernel_sigma(self):
         sampling_period = 200 * pq.ms
         sigma = 5 * pq.ms
         rate_expected = 10 * pq.Hz
-        spiketrain = homogeneous_poisson_process(rate_expected,
-                                                 t_start=0 * pq.s,
-                                                 t_stop=10 * pq.s)
+        spiketrain = StationaryPoissonProcess(
+            rate_expected, t_start=0 * pq.s, t_stop=10 * pq.s
+        ).generate_spiketrain()
         kernel_types = tuple(
             kern_cls for kern_cls in kernels.__dict__.values()
             if isinstance(kern_cls, type) and
@@ -777,8 +782,8 @@ def test_instantaneous_rate_regression_245(self):
     def test_instantaneous_rate_grows_with_sampling_period(self):
         np.random.seed(0)
         rate_expected = 10 * pq.Hz
-        spiketrain = homogeneous_poisson_process(rate=rate_expected,
-                                                 t_stop=10 * pq.s)
+        spiketrain = StationaryPoissonProcess(
+            rate=rate_expected, t_stop=10 * pq.s).generate_spiketrain()
         kernel = kernels.GaussianKernel(sigma=100 * pq.ms)
         rates_mean = []
         for sampling_period in np.linspace(1, 1000, num=10) * pq.ms:
@@ -842,6 +847,51 @@ def test_annotations(self):
         self.assertIn('kernel', rate.annotations)
         self.assertEqual(rate.annotations['kernel'], kernel_annotation)

+    def test_border_correction(self):
+        np.random.seed(0)
+        n_spiketrains = 125
+        rate = 50. * pq.Hz
+        t_start = 0. * pq.ms
+        t_stop = 1000. * pq.ms
+
+        sampling_period = 0.1 * pq.ms
+
+        trial_list = StationaryPoissonProcess(
+            rate=rate, t_start=t_start, t_stop=t_stop
+        ).generate_n_spiketrains(n_spiketrains)
+
+        for correction in (True, False):
+            rates = []
+            for trial in trial_list:
+                # calculate the instantaneous rate, discard extra dimension
+                instantaneous_rate = statistics.instantaneous_rate(
+                    spiketrains=trial,
+                    sampling_period=sampling_period,
+                    kernel='auto',
+                    border_correction=correction
+                )
+                rates.append(instantaneous_rate)
+
+            # Averaging over trials gives the mean estimated firing rate
+            # in each time bin.
+            # Note: the indexing [:, 0] is necessary to get the output as a
+            # one-dimensional array.
+            average_estimated_rate = np.mean(rates, axis=0)[:, 0]
+
+            rtol = 0.05  # five percent tolerance
+
+            if correction:
+                self.assertLess(np.max(average_estimated_rate),
+                                (1. + rtol) * rate.item())
+                self.assertGreater(np.min(average_estimated_rate),
+                                   (1. - rtol) * rate.item())
+            else:
+                self.assertLess(np.max(average_estimated_rate),
+                                (1. + rtol) * rate.item())
+                # The minimal rate deviates strongly in the uncorrected case.
+                self.assertLess(np.min(average_estimated_rate),
+                                (1. - rtol) * rate.item())


 class TimeHistogramTestCase(unittest.TestCase):
     def setUp(self):
@@ -909,8 +959,9 @@ def test_time_histogram_output(self):

     def test_annotations(self):
         np.random.seed(1)
-        spiketrains = [homogeneous_poisson_process(
-            rate=10 * pq.Hz, t_stop=10 * pq.s) for _ in range(10)]
+        spiketrains = StationaryPoissonProcess(
+            rate=10 * pq.Hz, t_stop=10 * pq.s).generate_n_spiketrains(
+            n_spiketrains=10)
         for output in ("counts", "mean", "rate"):
             histogram = statistics.time_histogram(spiketrains,
                                                   bin_size=3 * pq.ms,
@@ -931,7 +982,8 @@ def test_complexity_pdf_deprecated(self):
             spiketrain_a, spiketrain_b, spiketrain_c]
         # run the previous function, which will be deprecated
         targ = np.array([0.92, 0.01, 0.01, 0.06])
-        complexity = statistics.complexity_pdf(spiketrains, binsize=0.1*pq.s)
+        complexity = statistics.complexity_pdf(
+            spiketrains, bin_size=0.1*pq.s)
         assert_array_equal(targ, complexity.magnitude[:, 0])
         self.assertEqual(1, complexity.magnitude[:, 0].sum())
         self.assertEqual(len(spiketrains)+1, len(complexity))
