Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Postsynaptic model target #458

Merged
merged 4 commits into from
Sep 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/genn/genn/synapseGroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ class GENN_EXPORT SynapseGroup
and only applies to extra global parameters which are pointers. */
void setPSExtraGlobalParamLocation(const std::string &paramName, VarLocation loc);

//! Set name of neuron input variable postsynaptic model will target
/*! This should either be 'Isyn' or the name of one of the postsynaptic neuron's additional input variables. */
void setPSTargetVar(const std::string &varName);

//! Set location of sparse connectivity initialiser extra global parameter
/*! This is ignored for simulations on hardware with a single memory space
and only applies to extra global parameters which are pointers. */
Expand Down Expand Up @@ -198,6 +202,10 @@ class GENN_EXPORT SynapseGroup
/*! This is only used by extra global parameters which are pointers*/
VarLocation getPSExtraGlobalParamLocation(size_t index) const{ return m_PSExtraGlobalParamLocation.at(index); }

//! Get name of neuron input variable postsynaptic model will target
/*! This will either be 'Isyn' or the name of one of the postsynaptic neuron's additional input variables. */
const std::string &getPSTargetVar() const{ return m_PSTargetVar; }

//! Get location of sparse connectivity initialiser extra global parameter by name
/*! This is only used by extra global parameters which are pointers*/
VarLocation getSparseConnectivityExtraGlobalParamLocation(const std::string &paramName) const;
Expand Down Expand Up @@ -454,4 +462,8 @@ class GENN_EXPORT SynapseGroup
//! Name of the synapse group in which postsynaptic model is located
/*! This may not be the name of this group if it has been merged*/
std::string m_PSModelTargetName;

//! Name of neuron input variable postsynaptic model will target
/*! This should either be 'Isyn' or the name of one of the postsynaptic neuron's additional input variables. */
std::string m_PSTargetVar;
};
11 changes: 11 additions & 0 deletions pygenn/genn_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,17 @@ def has_individual_postsynaptic_vars(self):
"""Tests whether synaptic connectivity has
individual postsynaptic model variables"""
return (self.matrix_type & SynapseMatrixWeight_INDIVIDUAL_PSM) != 0

@property
def ps_target_var(self):
"""Gets name of neuron input variable postsynaptic model will target"""
return self.pop.get_pstarget_var()

@ps_output_var.setter
def ps_target_var(self, var):
"""Sets name of neuron input variable postsynaptic model will target"""
self.pop.set_pstarget_var(var)


def set_sparse_connections(self, pre_indices, post_indices):
"""Set ragged format connections between two groups of neurons
Expand Down
3 changes: 3 additions & 0 deletions src/genn/genn/code_generator/generateNeuronUpdate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ void CodeGenerator::generateNeuronUpdate(const filesystem::path &outputPath, con

Substitutions inSynSubs(&neuronSubs);
inSynSubs.addVarSubstitution("inSyn", "linSyn");

// Allow synapse group's PS output var to override what Isyn points to
inSynSubs.addVarSubstitution("Isyn", sg->getPSTargetVar(), true);

if (sg->getMatrixType() & SynapseMatrixWeight::INDIVIDUAL_PSM) {
inSynSubs.addVarNameSubstitution(psm->getVars(), "", "lps");
Expand Down
20 changes: 19 additions & 1 deletion src/genn/genn/synapseGroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,21 @@ void SynapseGroup::setPSVarLocation(const std::string &varName, VarLocation loc)
m_PSVarLocation[getPSModel()->getVarIndex(varName)] = loc;
}
//----------------------------------------------------------------------------
void SynapseGroup::setPSTargetVar(const std::string &varName)
{
// If varname is either 'ISyn' or name of target neuron group additional input variable, store
const auto additionalInputVars = getTrgNeuronGroup()->getNeuronModel()->getAdditionalInputVars();
if(varName == "Isyn" ||
std::find_if(additionalInputVars.cbegin(), additionalInputVars.cend(),
[&varName](const Models::Base::ParamVal &v){ return (v.name == varName); }) != additionalInputVars.cend())
{
m_PSTargetVar = varName;
}
else {
throw std::runtime_error("Target neuron group has no input variable '" + varName + "'");
}
}
//----------------------------------------------------------------------------
void SynapseGroup::setPSExtraGlobalParamLocation(const std::string &paramName, VarLocation loc)
{
const size_t extraGlobalParamIndex = getPSModel()->getExtraGlobalParamIndex(paramName);
Expand Down Expand Up @@ -437,7 +452,8 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType
m_WUPostVarLocation(wuPostVarInitialisers.size(), defaultVarLocation), m_WUExtraGlobalParamLocation(wu->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation),
m_PSVarLocation(psVarInitialisers.size(), defaultVarLocation), m_PSExtraGlobalParamLocation(ps->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation),
m_ConnectivityInitialiser(connectivityInitialiser), m_SparseConnectivityLocation(defaultSparseConnectivityLocation),
m_ConnectivityExtraGlobalParamLocation(connectivityInitialiser.getSnippet()->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation), m_PSModelTargetName(name)
m_ConnectivityExtraGlobalParamLocation(connectivityInitialiser.getSnippet()->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation), m_PSModelTargetName(name),
m_PSTargetVar("Isyn")
{
// Validate names
Utils::validatePopName(name, "Synapse group");
Expand Down Expand Up @@ -655,6 +671,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getPSHashDigest() const
Utils::updateHash(getPSModel()->getHashDigest(), hash);
Utils::updateHash(getMaxDendriticDelayTimesteps(), hash);
Utils::updateHash((getMatrixType() & SynapseMatrixWeight::INDIVIDUAL_PSM), hash);
Utils::updateHash(getPSTargetVar(), hash);
return hash.get_digest();
}
//----------------------------------------------------------------------------
Expand All @@ -664,6 +681,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getPSLinearCombineHashDige
Utils::updateHash(getPSModel()->getHashDigest(), hash);
Utils::updateHash(getMaxDendriticDelayTimesteps(), hash);
Utils::updateHash((getMatrixType() & SynapseMatrixWeight::INDIVIDUAL_PSM), hash);
Utils::updateHash(getPSTargetVar(), hash);
Utils::updateHash(getPSParams(), hash);
Utils::updateHash(getPSDerivedParams(), hash);
return hash.get_digest();
Expand Down
65 changes: 65 additions & 0 deletions tests/unit/synapseGroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,47 @@ class Sum : public CustomUpdateModels::Base
{"b", "scalar", VarAccessMode::READ_ONLY}});
};
IMPLEMENT_MODEL(Sum);

class LIFAdditional : public NeuronModels::Base
{
public:
DECLARE_MODEL(LIFAdditional, 7, 2);

SET_ADDITIONAL_INPUT_VARS({{"Isyn2", "scalar", "$(Ioffset)"}});
SET_SIM_CODE(
"if ($(RefracTime) <= 0.0) {\n"
" scalar alpha = ($(Isyn2) * $(Rmembrane)) + $(Vrest);\n"
" $(V) = alpha - ($(ExpTC) * (alpha - $(V)));\n"
"}\n"
"else {\n"
" $(RefracTime) -= DT;\n"
"}\n"
);

SET_THRESHOLD_CONDITION_CODE("$(RefracTime) <= 0.0 && $(V) >= $(Vthresh)");

SET_RESET_CODE(
"$(V) = $(Vreset);\n"
"$(RefracTime) = $(TauRefrac);\n");

SET_PARAM_NAMES({
"C", // Membrane capacitance
"TauM", // Membrane time constant [ms]
"Vrest", // Resting membrane potential [mV]
"Vreset", // Reset voltage [mV]
"Vthresh", // Spiking threshold [mV]
"Ioffset", // Offset current
"TauRefrac"});

SET_DERIVED_PARAMS({
{"ExpTC", [](const std::vector<double> &pars, double dt) { return std::exp(-dt / pars[1]); }},
{"Rmembrane", [](const std::vector<double> &pars, double) { return pars[1] / pars[0]; }}});

SET_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}});

SET_NEEDS_AUTO_REFRACTORY(false);
};
IMPLEMENT_MODEL(LIFAdditional);
} // Anonymous namespace

//--------------------------------------------------------------------------
Expand Down Expand Up @@ -766,3 +807,27 @@ TEST(SynapseGroup, SharedWeightSlaveInvalidMethods)
//setSparseConnectivityExtraGlobalParamLocation
//setMaxSourceConnections
}

TEST(SynapseGroup, InvalidPSOutputVar)
{
LIFAdditional::ParamValues paramVals(0.25, 10.0, 0.0, 0.0, 20.0, 0.0, 5.0);
LIFAdditional::VarValues varVals(0.0, 0.0);

ModelSpec model;
model.addNeuronPopulation<NeuronModels::SpikeSource>("Pre", 10, {}, {});
model.addNeuronPopulation<LIFAdditional>("Post", 10, paramVals, varVals);
auto *prePost = model.addSynapsePopulation<WeightUpdateModels::StaticPulse, PostsynapticModels::DeltaCurr>(
"PrePost", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY,
"Pre", "Post",
{}, { 1.0 },
{}, {});

prePost->setPSTargetVar("Isyn");
prePost->setPSTargetVar("Isyn2");
try {
prePost->setPSTargetVar("NonExistent");
FAIL();
}
catch (const std::runtime_error &) {
}
}