-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathneuromldb.py
executable file
·164 lines (121 loc) · 6.31 KB
/
neuromldb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import json
import ssl
import sys

import numpy as np
import quantities
import quantities as pq
from scipy.interpolate import interp1d
from neo import AnalogSignal

from neuronunit.models.static import StaticModel

# NOTE(review): this replaces the default HTTPS context factory for the whole
# process, so *every* HTTPS request using the default context (not only the
# NeuroML-DB API calls in this module) skips certificate verification.
# Presumably a workaround for certificate errors on some platforms (an old
# comment mentions OSX) -- confirm it is still needed.
ssl._create_default_https_context = ssl._create_unverified_context

# Alias the module that provides ``urlopen`` as ``urllib`` so the rest of the
# file can call ``urllib.urlopen`` on both Python 2 and Python 3.
if sys.version_info[0] < 3:
    import urllib
else:
    import urllib.request as urllib
class NeuroMLDBModel(object):
    """Client for one model's waveforms in the NeuroML-DB web API.

    Decoded JSON responses are cached per request URL, and waveforms converted
    to ``neo.AnalogSignal`` are cached per waveform ID, so repeated lookups do
    not re-download data.
    """

    def __init__(self, model_id="NMLCL000086", waveform_list=None):
        """Create a client for a single NeuroML-DB model.

        :param model_id: NeuroML-DB model identifier used in API query strings.
        :param waveform_list: optional pre-fetched list of waveform metadata
            dicts. When None or empty, it is fetched from the API on first use
            via ``fetch_waveform_list``.
        """
        self.model_id = model_id
        self.api_url = "https://neuroml-db.org/api/"  # See docs at: https://neuroml-db.org/api
        self.waveforms = waveform_list
        # Converted signals, keyed by waveform ID only.
        self.waveform_signals = {}
        # Decoded JSON responses, keyed by full request URL.
        self.url_responses = {}

    def read_api_url(self, url):
        """Return the decoded JSON body served at ``url``, caching it per URL.

        :param url: full API URL to request.
        :returns: the object produced by ``json.loads`` on the response body.
        """
        if url not in self.url_responses:
            from contextlib import closing

            # NOTE(review): an old comment here said this "works on linux but
            # not OSX"; presumably related to the TLS-certificate workaround
            # applied at import time -- confirm.
            # closing() releases the connection even if read() raises; it works
            # with the urlopen result on both Python 2 and Python 3.
            with closing(urllib.urlopen(url)) as response:
                body = response.read()
            if sys.version_info[0] >= 3:
                body = body.decode("utf-8")
            self.url_responses[url] = json.loads(body)
        return self.url_responses[url]

    def fetch_waveform_list(self):
        """Return the model's waveform metadata list, fetching it once if needed."""
        if not self.waveforms:
            data = self.read_api_url(self.api_url + "model?id=" + str(self.model_id))
            self.waveforms = data["waveform_list"]
        return self.waveforms

    def _waveform_metadata(self, waveform_id):
        """Return the waveform metadata dict whose ``"ID"`` equals ``waveform_id``.

        :raises Exception: if the model lists no such waveform.
        """
        for w in self.fetch_waveform_list():
            if w["ID"] == waveform_id:
                return w
        raise Exception("Did not find waveform " + str(waveform_id) +
                        ". See " + self.api_url + "model?id=" + str(self.model_id) +
                        " for the list of available model waveforms.")

    def fetch_waveform_as_AnalogSignal(self, waveform_id, resolution_ms=0.01, units="mV"):
        """Fetch a waveform and return it as a regularly sampled ``neo.AnalogSignal``.

        The API returns irregularly sampled points, so the values are linearly
        interpolated onto a grid with spacing ``resolution_ms``. If the
        waveform's metadata has ``Starts_From_Steady_State == 1``, the model's
        steady-state waveform is prepended to it.

        NOTE(review): the cache is keyed by ``waveform_id`` only, so a later
        call with a different ``resolution_ms`` or ``units`` returns the
        signal built by the first call.

        :param waveform_id: NeuroML-DB waveform identifier.
        :param resolution_ms: sampling period of the returned signal, in ms.
        :param units: units label passed to ``AnalogSignal``.
        :raises Exception: if ``waveform_id`` is not in the model's waveform list.
        """
        if waveform_id not in self.waveform_signals:
            # Validate the ID before touching the network.
            metadata = self._waveform_metadata(waveform_id)

            data = self.read_api_url(self.api_url + "waveform?id=" + str(waveform_id))

            # Times and values arrive as comma-separated strings.
            times = np.array(data["Times"].split(','), float)
            values = np.array(data["Variable_Values"].split(','), float)

            # Resample onto a regular grid (API samples are irregular).
            interpolator = interp1d(times, values, fill_value="extrapolate")
            values = interpolator(np.arange(times.min(), times.max(), resolution_ms))

            signal = AnalogSignal(values, units=units,
                                  sampling_period=resolution_ms * quantities.ms)

            if metadata["Starts_From_Steady_State"] == 1:
                rest_wave = self.get_steady_state_waveform()
                # np.array drops the units; the joined magnitudes are labelled
                # mV before building the combined signal.
                v = np.concatenate((np.array(rest_wave), np.array(signal))) * quantities.mV
                signal = AnalogSignal(v, units=units,
                                      sampling_period=resolution_ms * quantities.ms)

            self.waveform_signals[waveform_id] = signal
        return self.waveform_signals[waveform_id]

    def get_steady_state_waveform(self):
        """Return (and cache) the model's STEADY_STATE voltage waveform.

        :raises Exception: if the model has no STEADY_STATE voltage waveform.
        """
        cached = getattr(self, "steady_state_waveform", None)
        if cached is not None:
            return cached
        for w in self.fetch_waveform_list():
            if w["Protocol_ID"] == "STEADY_STATE" and w["Variable_Name"] == "Voltage":
                self.steady_state_waveform = self.fetch_waveform_as_AnalogSignal(w["ID"])
                return self.steady_state_waveform
        raise Exception("Did not find the resting waveform." +
                        " See " + self.api_url + "model?id=" + str(self.model_id) +
                        " for the list of available model waveforms.")

    def get_waveform_by_current(self, amplitude_nA):
        """Return the voltage response to a square current of ``amplitude_nA``.

        Negative amplitudes are matched against SQUARE waveforms; zero and
        positive ones against LONG_SQUARE waveforms. A waveform matches when its
        parsed label amplitude compares equal to ``amplitude_nA``.

        :param amplitude_nA: current amplitude as a ``quantities`` quantity.
        :raises Exception: if no matching voltage waveform exists.
        """
        protocol = "SQUARE" if amplitude_nA < 0 * pq.nA else "LONG_SQUARE"
        for w in self.fetch_waveform_list():
            # Check the protocol before parsing the label, so labels of other
            # protocols are never run through the "<amp> nA" parser.
            if w["Variable_Name"] != "Voltage" or w["Protocol_ID"] != protocol:
                continue
            if amplitude_nA == self.get_waveform_current_amplitude(w):
                return self.fetch_waveform_as_AnalogSignal(w["ID"])
        raise Exception("Did not find a Voltage waveform with injected " + str(amplitude_nA) +
                        ". See " + self.api_url + "model?id=" + str(self.model_id) +
                        " for the list of available model waveforms.")

    def _get_long_square_currents(self):
        """Return LONG_SQUARE voltage-waveform amplitudes, in waveform-list order.

        :raises Exception: unless exactly 4 such waveforms exist.
        """
        currents = [self.get_waveform_current_amplitude(w)
                    for w in self.fetch_waveform_list()
                    if w["Protocol_ID"] == "LONG_SQUARE" and w["Variable_Name"] == "Voltage"]
        if len(currents) != 4:
            raise Exception("The LONG_SQUARE protocol for the model should have 4 waveforms")
        return currents

    def get_druckmann2013_standard_current(self):
        """Return a one-element list holding the "RBx1.5" LONG_SQUARE amplitude.

        Assumes the API lists LONG_SQUARE waveforms so that the second-to-last
        one is the RBx1.5 waveform (presumably 1.5x rheobase) -- confirm.
        """
        return [self._get_long_square_currents()[-2]]  # 2nd to last one is RBx1.5 waveform

    def get_druckmann2013_strong_current(self):
        """Return a one-element list holding the "RBx3" LONG_SQUARE amplitude.

        Assumes the API lists LONG_SQUARE waveforms so that the last one is
        the RBx3 waveform (presumably 3x rheobase) -- confirm.
        """
        return [self._get_long_square_currents()[-1]]  # The last one is RBx3 waveform

    def get_druckmann2013_input_resistance_currents(self):
        """Return the negative amplitudes of all SQUARE voltage waveforms."""
        amplitudes = [self.get_waveform_current_amplitude(w)
                      for w in self.fetch_waveform_list()
                      if w["Protocol_ID"] == "SQUARE" and w["Variable_Name"] == "Voltage"]
        return [amp for amp in amplitudes if amp < 0 * pq.nA]

    def get_waveform_current_amplitude(self, waveform):
        """Parse a waveform label such as ``"0.5 nA"`` into ``0.5 * pq.nA``."""
        return float(waveform["Waveform_Label"].replace(" nA", "")) * pq.nA
class NeuroMLDBStaticModel(StaticModel):
    """Static model that answers square-current injections with NeuroML-DB waveforms.

    Rather than simulating, ``inject_square_current`` looks up the stored
    voltage response for the requested amplitude.
    """

    def __init__(self, model_id, **params):
        # NOTE(review): StaticModel.__init__ is never called and ``params`` is
        # unused here -- confirm that is intended.
        self.nmldb_model = NeuroMLDBModel(model_id)
        # Fetch the waveform metadata list once, at construction time.
        self.nmldb_model.fetch_waveform_list()

    def inject_square_current(self, current):
        """Store in ``self.vm`` and return the response for ``current["amplitude"]``."""
        amplitude = current["amplitude"]
        response = self.nmldb_model.get_waveform_by_current(amplitude)
        self.vm = response
        return response