Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate raster plot #623

Merged
merged 11 commits into from
Feb 21, 2023
45 changes: 45 additions & 0 deletions src/lava/utils/plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import numpy as np
import matplotlib.pyplot as plt


def raster_plot(spks, stride=6, fig=None, color='b', alpha=1):
ssgier marked this conversation as resolved.
Show resolved Hide resolved
"""Generate raster plot of spiking activity.

Parameters
----------
spks : np.ndarray shape (num_neurons, num_timesteps)
ssgier marked this conversation as resolved.
Show resolved Hide resolved
ssgier marked this conversation as resolved.
Show resolved Hide resolved
Spiking activity of neurons, a spike is indicated by a one
stride : int
Stride for plotting neurons
"""
ssgier marked this conversation as resolved.
Show resolved Hide resolved
num_neurons = spks.shape[0]
num_time_steps = spks.shape[1]
mathisrichter marked this conversation as resolved.
Show resolved Hide resolved

if stride >= num_neurons:
raise ValueError("Stride must be less than the number of neurons")
ssgier marked this conversation as resolved.
Show resolved Hide resolved

time_steps = np.arange(0, num_time_steps, 1)
if fig is None:
fig = plt.figure(figsize=(10, 5))
ssgier marked this conversation as resolved.
Show resolved Hide resolved

plt.xlim(-1, num_time_steps)
plt.yticks([])

plt.xlabel('Time steps')
plt.ylabel('Neurons')
ssgier marked this conversation as resolved.
Show resolved Hide resolved

for i in range(0, num_neurons, stride):
spike_times = time_steps[spks[i] == 1]
plt.plot(spike_times,
i * np.ones(spike_times.shape),
linestyle=' ',
marker='o',
markersize=1.5,
ssgier marked this conversation as resolved.
Show resolved Hide resolved
color=color,
alpha=alpha)

return fig
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"id": "3fd51524",
"metadata": {},
"source": [
"*Copyright (C) 2022 Intel Corporation*<br>\n",
"*Copyright (C) 2022-23 Intel Corporation*<br>\n",
"*SPDX-License-Identifier: BSD-3-Clause*<br>\n",
"*See: https://spdx.org/licenses/*\n",
"\n",
Expand Down Expand Up @@ -889,51 +889,6 @@
"To this end, we display neurons on the vertical axis and mark the time step when a neuron spiked."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "abac04ec",
"metadata": {},
"outputs": [],
"source": [
"def raster_plot(spks, stride=6, fig=None, color='b', alpha=1):\n",
" \"\"\"Generate raster plot of spiking activity.\n",
" \n",
" Parameters\n",
" ----------\n",
" \n",
" spks : np.ndarray shape (num_neurons, timesteps)\n",
" Spiking activity of neurons, a spike is indicated by a one \n",
" stride : int\n",
" Stride for plotting neurons\n",
" \"\"\"\n",
" num_time_steps = spks.shape[1]\n",
" assert stride < num_time_steps, \"Stride must be smaller than number of time steps\"\n",
" \n",
" time_steps = np.arange(0, num_time_steps, 1)\n",
" if fig is None:\n",
" fig = plt.figure(figsize=(10,5))\n",
" timesteps = spks.shape[1]\n",
" \n",
" plt.xlim(-1, num_time_steps)\n",
" plt.yticks([])\n",
" \n",
" plt.xlabel('Time steps')\n",
" plt.ylabel('Neurons')\n",
" \n",
" for i in range(0, dim, stride):\n",
" spike_times = time_steps[spks[i] == 1]\n",
" plt.plot(spike_times,\n",
" i * np.ones(spike_times.shape),\n",
" linestyle=' ',\n",
" marker='o',\n",
" markersize=1.5,\n",
" color=color,\n",
" alpha=alpha)\n",
" \n",
" return fig "
]
},
{
"cell_type": "code",
"execution_count": 18,
Expand All @@ -954,6 +909,7 @@
}
],
"source": [
"from lava.utils.plots import raster_plot\n",
"fig = raster_plot(spks=spks_balanced)"
]
},
Expand Down