Skip to content

Commit

Permalink
Merge pull request #458 from genn-team/postsynaptic_model_target
Browse files Browse the repository at this point in the history
Postsynaptic model target
  • Loading branch information
neworderofjamie authored Sep 15, 2021
2 parents a7ddf5c + 2a6d482 commit b6a4f56
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 1 deletion.
12 changes: 12 additions & 0 deletions include/genn/genn/synapseGroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ class GENN_EXPORT SynapseGroup
and only applies to extra global parameters which are pointers. */
void setPSExtraGlobalParamLocation(const std::string &paramName, VarLocation loc);

//! Set name of neuron input variable postsynaptic model will target
/*! This should either be 'Isyn' or the name of one of the postsynaptic neuron's additional input variables. */
void setPSTargetVar(const std::string &varName);

//! Set location of sparse connectivity initialiser extra global parameter
/*! This is ignored for simulations on hardware with a single memory space
and only applies to extra global parameters which are pointers. */
Expand Down Expand Up @@ -198,6 +202,10 @@ class GENN_EXPORT SynapseGroup
/*! This is only used by extra global parameters which are pointers*/
VarLocation getPSExtraGlobalParamLocation(size_t index) const{ return m_PSExtraGlobalParamLocation.at(index); }

//! Get name of neuron input variable postsynaptic model will target
/*! This will either be 'Isyn' or the name of one of the postsynaptic neuron's additional input variables. */
const std::string &getPSTargetVar() const{ return m_PSTargetVar; }

//! Get location of sparse connectivity initialiser extra global parameter by name
/*! This is only used by extra global parameters which are pointers*/
VarLocation getSparseConnectivityExtraGlobalParamLocation(const std::string &paramName) const;
Expand Down Expand Up @@ -454,4 +462,8 @@ class GENN_EXPORT SynapseGroup
//! Name of the synapse group in which postsynaptic model is located
/*! This may not be the name of this group if it has been merged*/
std::string m_PSModelTargetName;

//! Name of neuron input variable postsynaptic model will target
/*! This should either be 'Isyn' or the name of one of the postsynaptic neuron's additional input variables. */
std::string m_PSTargetVar;
};
11 changes: 11 additions & 0 deletions pygenn/genn_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,17 @@ def has_individual_postsynaptic_vars(self):
"""Tests whether synaptic connectivity has
individual postsynaptic model variables"""
return (self.matrix_type & SynapseMatrixWeight_INDIVIDUAL_PSM) != 0

@property
def ps_target_var(self):
"""Gets name of neuron input variable postsynaptic model will target"""
return self.pop.get_pstarget_var()

@ps_output_var.setter
def ps_target_var(self, var):
"""Sets name of neuron input variable postsynaptic model will target"""
self.pop.set_pstarget_var(var)


def set_sparse_connections(self, pre_indices, post_indices):
"""Set ragged format connections between two groups of neurons
Expand Down
3 changes: 3 additions & 0 deletions src/genn/genn/code_generator/generateNeuronUpdate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ void CodeGenerator::generateNeuronUpdate(const filesystem::path &outputPath, con

Substitutions inSynSubs(&neuronSubs);
inSynSubs.addVarSubstitution("inSyn", "linSyn");

// Allow synapse group's PS output var to override what Isyn points to
inSynSubs.addVarSubstitution("Isyn", sg->getPSTargetVar(), true);

if (sg->getMatrixType() & SynapseMatrixWeight::INDIVIDUAL_PSM) {
inSynSubs.addVarNameSubstitution(psm->getVars(), "", "lps");
Expand Down
20 changes: 19 additions & 1 deletion src/genn/genn/synapseGroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,21 @@ void SynapseGroup::setPSVarLocation(const std::string &varName, VarLocation loc)
m_PSVarLocation[getPSModel()->getVarIndex(varName)] = loc;
}
//----------------------------------------------------------------------------
void SynapseGroup::setPSTargetVar(const std::string &varName)
{
// If varname is either 'ISyn' or name of target neuron group additional input variable, store
const auto additionalInputVars = getTrgNeuronGroup()->getNeuronModel()->getAdditionalInputVars();
if(varName == "Isyn" ||
std::find_if(additionalInputVars.cbegin(), additionalInputVars.cend(),
[&varName](const Models::Base::ParamVal &v){ return (v.name == varName); }) != additionalInputVars.cend())
{
m_PSTargetVar = varName;
}
else {
throw std::runtime_error("Target neuron group has no input variable '" + varName + "'");
}
}
//----------------------------------------------------------------------------
void SynapseGroup::setPSExtraGlobalParamLocation(const std::string &paramName, VarLocation loc)
{
const size_t extraGlobalParamIndex = getPSModel()->getExtraGlobalParamIndex(paramName);
Expand Down Expand Up @@ -437,7 +452,8 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType
m_WUPostVarLocation(wuPostVarInitialisers.size(), defaultVarLocation), m_WUExtraGlobalParamLocation(wu->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation),
m_PSVarLocation(psVarInitialisers.size(), defaultVarLocation), m_PSExtraGlobalParamLocation(ps->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation),
m_ConnectivityInitialiser(connectivityInitialiser), m_SparseConnectivityLocation(defaultSparseConnectivityLocation),
m_ConnectivityExtraGlobalParamLocation(connectivityInitialiser.getSnippet()->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation), m_PSModelTargetName(name)
m_ConnectivityExtraGlobalParamLocation(connectivityInitialiser.getSnippet()->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation), m_PSModelTargetName(name),
m_PSTargetVar("Isyn")
{
// Validate names
Utils::validatePopName(name, "Synapse group");
Expand Down Expand Up @@ -655,6 +671,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getPSHashDigest() const
Utils::updateHash(getPSModel()->getHashDigest(), hash);
Utils::updateHash(getMaxDendriticDelayTimesteps(), hash);
Utils::updateHash((getMatrixType() & SynapseMatrixWeight::INDIVIDUAL_PSM), hash);
Utils::updateHash(getPSTargetVar(), hash);
return hash.get_digest();
}
//----------------------------------------------------------------------------
Expand All @@ -664,6 +681,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getPSLinearCombineHashDige
Utils::updateHash(getPSModel()->getHashDigest(), hash);
Utils::updateHash(getMaxDendriticDelayTimesteps(), hash);
Utils::updateHash((getMatrixType() & SynapseMatrixWeight::INDIVIDUAL_PSM), hash);
Utils::updateHash(getPSTargetVar(), hash);
Utils::updateHash(getPSParams(), hash);
Utils::updateHash(getPSDerivedParams(), hash);
return hash.get_digest();
Expand Down
65 changes: 65 additions & 0 deletions tests/unit/synapseGroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,47 @@ class Sum : public CustomUpdateModels::Base
{"b", "scalar", VarAccessMode::READ_ONLY}});
};
IMPLEMENT_MODEL(Sum);

class LIFAdditional : public NeuronModels::Base
{
public:
DECLARE_MODEL(LIFAdditional, 7, 2);

SET_ADDITIONAL_INPUT_VARS({{"Isyn2", "scalar", "$(Ioffset)"}});
SET_SIM_CODE(
"if ($(RefracTime) <= 0.0) {\n"
" scalar alpha = ($(Isyn2) * $(Rmembrane)) + $(Vrest);\n"
" $(V) = alpha - ($(ExpTC) * (alpha - $(V)));\n"
"}\n"
"else {\n"
" $(RefracTime) -= DT;\n"
"}\n"
);

SET_THRESHOLD_CONDITION_CODE("$(RefracTime) <= 0.0 && $(V) >= $(Vthresh)");

SET_RESET_CODE(
"$(V) = $(Vreset);\n"
"$(RefracTime) = $(TauRefrac);\n");

SET_PARAM_NAMES({
"C", // Membrane capacitance
"TauM", // Membrane time constant [ms]
"Vrest", // Resting membrane potential [mV]
"Vreset", // Reset voltage [mV]
"Vthresh", // Spiking threshold [mV]
"Ioffset", // Offset current
"TauRefrac"});

SET_DERIVED_PARAMS({
{"ExpTC", [](const std::vector<double> &pars, double dt) { return std::exp(-dt / pars[1]); }},
{"Rmembrane", [](const std::vector<double> &pars, double) { return pars[1] / pars[0]; }}});

SET_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}});

SET_NEEDS_AUTO_REFRACTORY(false);
};
IMPLEMENT_MODEL(LIFAdditional);
} // Anonymous namespace

//--------------------------------------------------------------------------
Expand Down Expand Up @@ -766,3 +807,27 @@ TEST(SynapseGroup, SharedWeightSlaveInvalidMethods)
//setSparseConnectivityExtraGlobalParamLocation
//setMaxSourceConnections
}

TEST(SynapseGroup, InvalidPSOutputVar)
{
LIFAdditional::ParamValues paramVals(0.25, 10.0, 0.0, 0.0, 20.0, 0.0, 5.0);
LIFAdditional::VarValues varVals(0.0, 0.0);

ModelSpec model;
model.addNeuronPopulation<NeuronModels::SpikeSource>("Pre", 10, {}, {});
model.addNeuronPopulation<LIFAdditional>("Post", 10, paramVals, varVals);
auto *prePost = model.addSynapsePopulation<WeightUpdateModels::StaticPulse, PostsynapticModels::DeltaCurr>(
"PrePost", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY,
"Pre", "Post",
{}, { 1.0 },
{}, {});

prePost->setPSTargetVar("Isyn");
prePost->setPSTargetVar("Isyn2");
try {
prePost->setPSTargetVar("NonExistent");
FAIL();
}
catch (const std::runtime_error &) {
}
}

0 comments on commit b6a4f56

Please sign in to comment.