diff --git a/include/genn/genn/synapseGroup.h b/include/genn/genn/synapseGroup.h index 2e111d3aa8..37559134cf 100644 --- a/include/genn/genn/synapseGroup.h +++ b/include/genn/genn/synapseGroup.h @@ -64,6 +64,10 @@ class GENN_EXPORT SynapseGroup and only applies to extra global parameters which are pointers. */ void setPSExtraGlobalParamLocation(const std::string ¶mName, VarLocation loc); + //! Set name of neuron input variable postsynaptic model will target + /*! This should either be 'Isyn' or the name of one of the postsynaptic neuron's additional input variables. */ + void setPSTargetVar(const std::string &varName); + //! Set location of sparse connectivity initialiser extra global parameter /*! This is ignored for simulations on hardware with a single memory space and only applies to extra global parameters which are pointers. */ @@ -198,6 +202,10 @@ class GENN_EXPORT SynapseGroup /*! This is only used by extra global parameters which are pointers*/ VarLocation getPSExtraGlobalParamLocation(size_t index) const{ return m_PSExtraGlobalParamLocation.at(index); } + //! Get name of neuron input variable postsynaptic model will target + /*! This will either be 'Isyn' or the name of one of the postsynaptic neuron's additional input variables. */ + const std::string &getPSTargetVar() const{ return m_PSTargetVar; } + //! Get location of sparse connectivity initialiser extra global parameter by name /*! This is only used by extra global parameters which are pointers*/ VarLocation getSparseConnectivityExtraGlobalParamLocation(const std::string ¶mName) const; @@ -454,4 +462,8 @@ class GENN_EXPORT SynapseGroup //! Name of the synapse group in which postsynaptic model is located /*! This may not be the name of this group if it has been merged*/ std::string m_PSModelTargetName; + + //! Name of neuron input variable postsynaptic model will target + /*! This should either be 'Isyn' or the name of one of the postsynaptic neuron's additional input variables. */ + std::string m_PSTargetVar; }; diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index 7f17d86bef..0147e04f92 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -817,6 +817,17 @@ def has_individual_postsynaptic_vars(self): """Tests whether synaptic connectivity has individual postsynaptic model variables""" return (self.matrix_type & SynapseMatrixWeight_INDIVIDUAL_PSM) != 0 + + @property + def ps_target_var(self): + """Gets name of neuron input variable postsynaptic model will target""" + return self.pop.get_pstarget_var() + + @ps_output_var.setter + def ps_target_var(self, var): + """Sets name of neuron input variable postsynaptic model will target""" + self.pop.set_pstarget_var(var) + def set_sparse_connections(self, pre_indices, post_indices): """Set ragged format connections between two groups of neurons diff --git a/src/genn/genn/code_generator/generateNeuronUpdate.cc b/src/genn/genn/code_generator/generateNeuronUpdate.cc index c01ffafae2..3eeafb65a1 100644 --- a/src/genn/genn/code_generator/generateNeuronUpdate.cc +++ b/src/genn/genn/code_generator/generateNeuronUpdate.cc @@ -256,6 +256,9 @@ void CodeGenerator::generateNeuronUpdate(const filesystem::path &outputPath, con Substitutions inSynSubs(&neuronSubs); inSynSubs.addVarSubstitution("inSyn", "linSyn"); + + // Allow synapse group's PS output var to override what Isyn points to + inSynSubs.addVarSubstitution("Isyn", sg->getPSTargetVar(), true); if (sg->getMatrixType() & SynapseMatrixWeight::INDIVIDUAL_PSM) { inSynSubs.addVarNameSubstitution(psm->getVars(), "", "lps"); diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc index 18512e8624..c351b79a41 100644 --- a/src/genn/genn/synapseGroup.cc +++ b/src/genn/genn/synapseGroup.cc @@ -75,6 +75,21 @@ void SynapseGroup::setPSVarLocation(const std::string &varName, VarLocation loc) m_PSVarLocation[getPSModel()->getVarIndex(varName)] = loc; } //---------------------------------------------------------------------------- +void SynapseGroup::setPSTargetVar(const std::string &varName) +{ + // If varname is either 'ISyn' or name of target neuron group additional input variable, store + const auto additionalInputVars = getTrgNeuronGroup()->getNeuronModel()->getAdditionalInputVars(); + if(varName == "Isyn" || + std::find_if(additionalInputVars.cbegin(), additionalInputVars.cend(), + [&varName](const Models::Base::ParamVal &v){ return (v.name == varName); }) != additionalInputVars.cend()) + { + m_PSTargetVar = varName; + } + else { + throw std::runtime_error("Target neuron group has no input variable '" + varName + "'"); + } +} +//---------------------------------------------------------------------------- void SynapseGroup::setPSExtraGlobalParamLocation(const std::string ¶mName, VarLocation loc) { const size_t extraGlobalParamIndex = getPSModel()->getExtraGlobalParamIndex(paramName); @@ -437,7 +452,8 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType m_WUPostVarLocation(wuPostVarInitialisers.size(), defaultVarLocation), m_WUExtraGlobalParamLocation(wu->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation), m_PSVarLocation(psVarInitialisers.size(), defaultVarLocation), m_PSExtraGlobalParamLocation(ps->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation), m_ConnectivityInitialiser(connectivityInitialiser), m_SparseConnectivityLocation(defaultSparseConnectivityLocation), - m_ConnectivityExtraGlobalParamLocation(connectivityInitialiser.getSnippet()->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation), m_PSModelTargetName(name) + m_ConnectivityExtraGlobalParamLocation(connectivityInitialiser.getSnippet()->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation), m_PSModelTargetName(name), + m_PSTargetVar("Isyn") { // Validate names Utils::validatePopName(name, "Synapse group"); @@ -655,6 +671,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getPSHashDigest() const Utils::updateHash(getPSModel()->getHashDigest(), hash); Utils::updateHash(getMaxDendriticDelayTimesteps(), hash); Utils::updateHash((getMatrixType() & SynapseMatrixWeight::INDIVIDUAL_PSM), hash); + Utils::updateHash(getPSTargetVar(), hash); return hash.get_digest(); } //---------------------------------------------------------------------------- @@ -664,6 +681,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getPSLinearCombineHashDige Utils::updateHash(getPSModel()->getHashDigest(), hash); Utils::updateHash(getMaxDendriticDelayTimesteps(), hash); Utils::updateHash((getMatrixType() & SynapseMatrixWeight::INDIVIDUAL_PSM), hash); + Utils::updateHash(getPSTargetVar(), hash); Utils::updateHash(getPSParams(), hash); Utils::updateHash(getPSDerivedParams(), hash); return hash.get_digest(); diff --git a/tests/unit/synapseGroup.cc b/tests/unit/synapseGroup.cc index b37eb190e7..177131a742 100644 --- a/tests/unit/synapseGroup.cc +++ b/tests/unit/synapseGroup.cc @@ -105,6 +105,47 @@ class Sum : public CustomUpdateModels::Base {"b", "scalar", VarAccessMode::READ_ONLY}}); }; IMPLEMENT_MODEL(Sum); + +class LIFAdditional : public NeuronModels::Base +{ +public: + DECLARE_MODEL(LIFAdditional, 7, 2); + + SET_ADDITIONAL_INPUT_VARS({{"Isyn2", "scalar", "$(Ioffset)"}}); + SET_SIM_CODE( + "if ($(RefracTime) <= 0.0) {\n" + " scalar alpha = ($(Isyn2) * $(Rmembrane)) + $(Vrest);\n" + " $(V) = alpha - ($(ExpTC) * (alpha - $(V)));\n" + "}\n" + "else {\n" + " $(RefracTime) -= DT;\n" + "}\n" + ); + + SET_THRESHOLD_CONDITION_CODE("$(RefracTime) <= 0.0 && $(V) >= $(Vthresh)"); + + SET_RESET_CODE( + "$(V) = $(Vreset);\n" + "$(RefracTime) = $(TauRefrac);\n"); + + SET_PARAM_NAMES({ + "C", // Membrane capacitance + "TauM", // Membrane time constant [ms] + "Vrest", // Resting membrane potential [mV] + "Vreset", // Reset voltage [mV] + "Vthresh", // Spiking threshold [mV] + "Ioffset", // Offset current + "TauRefrac"}); + + SET_DERIVED_PARAMS({ + {"ExpTC", [](const std::vector &pars, double dt) { return std::exp(-dt / pars[1]); }}, + {"Rmembrane", [](const std::vector &pars, double) { return pars[1] / pars[0]; }}}); + + SET_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); + + SET_NEEDS_AUTO_REFRACTORY(false); +}; +IMPLEMENT_MODEL(LIFAdditional); } // Anonymous namespace //-------------------------------------------------------------------------- @@ -766,3 +807,27 @@ TEST(SynapseGroup, SharedWeightSlaveInvalidMethods) //setSparseConnectivityExtraGlobalParamLocation //setMaxSourceConnections } + +TEST(SynapseGroup, InvalidPSOutputVar) +{ + LIFAdditional::ParamValues paramVals(0.25, 10.0, 0.0, 0.0, 20.0, 0.0, 5.0); + LIFAdditional::VarValues varVals(0.0, 0.0); + + ModelSpec model; + model.addNeuronPopulation("Pre", 10, {}, {}); + model.addNeuronPopulation("Post", 10, paramVals, varVals); + auto *prePost = model.addSynapsePopulation( + "PrePost", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Pre", "Post", + {}, { 1.0 }, + {}, {}); + + prePost->setPSTargetVar("Isyn"); + prePost->setPSTargetVar("Isyn2"); + try { + prePost->setPSTargetVar("NonExistent"); + FAIL(); + } + catch (const std::runtime_error &) { + } +}