diff --git a/include/genn/genn/postsynapticModels.h b/include/genn/genn/postsynapticModels.h
index 888ebd97b..d62c76f83 100644
--- a/include/genn/genn/postsynapticModels.h
+++ b/include/genn/genn/postsynapticModels.h
@@ -9,9 +9,7 @@
//----------------------------------------------------------------------------
// Macros
//----------------------------------------------------------------------------
-#define SET_DECAY_CODE(DECAY_CODE) virtual std::string getDecayCode() const override{ return DECAY_CODE; }
-#define SET_CURRENT_CONVERTER_CODE(CURRENT_CONVERTER_CODE) virtual std::string getApplyInputCode() const override{ return "Isyn += " CURRENT_CONVERTER_CODE ";"; }
-#define SET_APPLY_INPUT_CODE(APPLY_INPUT_CODE) virtual std::string getApplyInputCode() const override{ return APPLY_INPUT_CODE; }
+#define SET_SIM_CODE(SIM_CODE) virtual std::string getSimCode() const override{ return SIM_CODE; }
#define SET_NEURON_VAR_REFS(...) virtual VarRefVec getNeuronVarRefs() const override{ return __VA_ARGS__; }
//----------------------------------------------------------------------------
@@ -32,8 +30,7 @@ class GENN_EXPORT Base : public Models::Base
//! Gets names and types of model variable references
virtual VarRefVec getNeuronVarRefs() const{ return {}; }
- virtual std::string getDecayCode() const{ return ""; }
- virtual std::string getApplyInputCode() const{ return ""; }
+ virtual std::string getSimCode() const{ return ""; }
//----------------------------------------------------------------------------
// Public API
@@ -72,8 +69,7 @@ class GENN_EXPORT Init : public Snippet::Init
const std::unordered_map &getVarInitialisers() const{ return m_VarInitialisers; }
const std::unordered_map &getNeuronVarReferences() const{ return m_NeuronVarReferences; }
- const std::vector &getDecayCodeTokens() const{ return m_DecayCodeTokens; }
- const std::vector &getApplyInputCodeTokens() const{ return m_ApplyInputCodeTokens; }
+ const std::vector &getSimCodeTokens() const{ return m_SimCodeTokens; }
void finalise(double dt);
@@ -81,8 +77,7 @@ class GENN_EXPORT Init : public Snippet::Init
//------------------------------------------------------------------------
// Members
//------------------------------------------------------------------------
- std::vector m_DecayCodeTokens;
- std::vector m_ApplyInputCodeTokens;
+ std::vector m_SimCodeTokens;
std::unordered_map m_VarInitialisers;
std::unordered_map m_NeuronVarReferences;
@@ -99,9 +94,9 @@ class ExpCurr : public Base
public:
DECLARE_SNIPPET(ExpCurr);
- SET_DECAY_CODE("inSyn *= expDecay;");
-
- SET_CURRENT_CONVERTER_CODE("init * inSyn");
+ SET_SIM_CODE(
+ "injectCurrent(init * inSyn);\n"
+ "inSyn *= expDecay;\n");
SET_PARAMS({"tau"});
@@ -124,9 +119,9 @@ class ExpCond : public Base
public:
DECLARE_SNIPPET(ExpCond);
- SET_DECAY_CODE("inSyn*=expDecay;");
-
- SET_CURRENT_CONVERTER_CODE("inSyn * (E - V)");
+ SET_SIM_CODE(
+ "injectCurrent(inSyn * (E - V));\n"
+ "inSyn *= expDecay;\n");
SET_PARAMS({"tau", "E"});
@@ -145,6 +140,8 @@ class DeltaCurr : public Base
public:
DECLARE_SNIPPET(DeltaCurr);
- SET_CURRENT_CONVERTER_CODE("inSyn; inSyn = 0");
+ SET_SIM_CODE(
+ "injectCurrent(inSyn);\n"
+ "inSyn = 0.0;\n");
};
} // namespace GeNN::PostsynapticModels
diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py
index 0754a1c59..96b7f23ed 100644
--- a/pygenn/genn_model.py
+++ b/pygenn/genn_model.py
@@ -1020,8 +1020,8 @@ def create_neuron_model(class_name, params=None, param_names=None,
def create_postsynaptic_model(class_name, params=None, param_names=None,
var_name_types=None, neuron_var_refs=None,
- derived_params=None, decay_code=None,
- apply_input_code=None,
+ derived_params=None, sim_code=None,
+ decay_code=None, apply_input_code=None,
extra_global_params=None):
"""This helper function creates a custom PostsynapticModel class.
See also:
@@ -1042,21 +1042,21 @@ def create_postsynaptic_model(class_name, params=None, param_names=None,
derived_params -- list of pairs, where the first member is string
with name of the derived parameter and the second
should be a functor returned by create_dpf_class
+ sim_code -- string with the decay code
decay_code -- string with the decay code
apply_input_code -- string with the apply input code
extra_global_params -- list of pairs of strings with names and
types of additional parameters
"""
body = {}
-
- if decay_code is not None:
- body["get_decay_code"] =\
- lambda self: dedent(upgrade_code_string(decay_code, class_name))
-
- if apply_input_code is not None:
- body["get_apply_input_code"] =\
- lambda self: dedent(upgrade_code_string(apply_input_code,
- class_name))
+ if decay_code is not None or apply_input_code is not None:
+ raise RuntimeError("Creating postsynaptic models with seperate "
+ "'decay_code' and 'apply_code' code strings is no "
+ "longer supported. Please provide 'sim_code' using "
+ "the injectCurrent(X) function to provide input.")
+ if sim_code is not None:
+ body["get_sim_code"] =\
+ lambda self: dedent(upgrade_code_string(sim_code, class_name))
if var_name_types is not None:
body["get_vars"] = \
diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc
index 370e7ce12..4641fcc47 100644
--- a/pygenn/src/genn.cc
+++ b/pygenn/src/genn.cc
@@ -173,8 +173,7 @@ class PyPostsynapticModelBase : public PySnippet
virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); }
virtual VarRefVec getNeuronVarRefs() const override { PYBIND11_OVERRIDE_NAME(VarRefVec, Base, "get_neuron_var_refs", getNeuronVarRefs); }
- virtual std::string getDecayCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_decay_code", getDecayCode); }
- virtual std::string getApplyInputCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_apply_input_code", getApplyInputCode); }
+ virtual std::string getSimCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_sim_code", getSimCode); }
};
//----------------------------------------------------------------------------
@@ -846,7 +845,7 @@ PYBIND11_MODULE(genn, m)
.def("get_additional_input_vars", &NeuronModels::Base::getAdditionalInputVars)
.def("is_auto_refractory_required", &NeuronModels::Base::isAutoRefractoryRequired);
-
+
//------------------------------------------------------------------------
// genn.PostsynapticModelBase
//------------------------------------------------------------------------
@@ -856,9 +855,8 @@ PYBIND11_MODULE(genn, m)
.def("get_vars", &PostsynapticModels::Base::getVars)
.def("get_neuron_var_refs", &PostsynapticModels::Base::getNeuronVarRefs)
- .def("get_decay_code", &PostsynapticModels::Base::getDecayCode)
- .def("get_apply_input_code", &PostsynapticModels::Base::getApplyInputCode);
-
+ .def("get_sim_code", &PostsynapticModels::Base::getSimCode);
+
//------------------------------------------------------------------------
// genn.WeightUpdateModelBase
//------------------------------------------------------------------------
diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
index 619303d4e..9ec32e279 100644
--- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
@@ -120,9 +120,10 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env
// **TODO** naming convention
psmEnv.add(getScalarType(), "inSyn", "linSyn");
-
- // Allow synapse group's PS output var to override what Isyn points to
- psmEnv.add(getScalarType(), "Isyn", "$(_" + getArchetype().getPostTargetVar() + ")");
+
+ // Define inject current function
+ psmEnv.add(Type::ResolvedType::createFunction(Type::Void, {getScalarType()}),
+ "injectCurrent", "$(_" + getArchetype().getPostTargetVar() + ") += $(0)");
// Create an environment which caches variables in local variables if they are accessed
EnvironmentLocalVarCache varEnv(
@@ -133,11 +134,8 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env
});
// Pretty print code back to environment
- Transpiler::ErrorHandler applyInputErrorHandler("Synapse group '" + getArchetype().getName() + "' postsynaptic model apply input code");
- prettyPrintStatements(getArchetype().getPSInitialiser().getApplyInputCodeTokens(), getTypeContext(), varEnv, applyInputErrorHandler);
-
- Transpiler::ErrorHandler decayErrorHandler("Synapse group '" + getArchetype().getName() + "' postsynaptic model decay code");
- prettyPrintStatements(getArchetype().getPSInitialiser().getDecayCodeTokens(), getTypeContext(), varEnv, decayErrorHandler);
+ Transpiler::ErrorHandler errorHandler("Synapse group '" + getArchetype().getName() + "' postsynaptic model apply sim code");
+ prettyPrintStatements(getArchetype().getPSInitialiser().getSimCodeTokens(), getTypeContext(), varEnv, errorHandler);
// Write back linSyn
varEnv.printLine("$(_out_post)[" + ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(id)") + "] = linSyn;");
diff --git a/src/genn/genn/neuronGroup.cc b/src/genn/genn/neuronGroup.cc
index 15da10fd9..249efddee 100644
--- a/src/genn/genn/neuronGroup.cc
+++ b/src/genn/genn/neuronGroup.cc
@@ -336,8 +336,7 @@ bool NeuronGroup::isSimRNGRequired() const
return std::any_of(getInSyn().cbegin(), getInSyn().cend(),
[](const SynapseGroupInternal *sg)
{
- return (Utils::isRNGRequired(sg->getPSInitialiser().getApplyInputCodeTokens()) ||
- Utils::isRNGRequired(sg->getPSInitialiser().getDecayCodeTokens()));
+ return Utils::isRNGRequired(sg->getPSInitialiser().getSimCodeTokens());
});
}
//----------------------------------------------------------------------------
diff --git a/src/genn/genn/postsynapticModels.cc b/src/genn/genn/postsynapticModels.cc
index fe131e61c..b5ada146e 100644
--- a/src/genn/genn/postsynapticModels.cc
+++ b/src/genn/genn/postsynapticModels.cc
@@ -22,8 +22,7 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const
Snippet::Base::updateHash(hash);
Utils::updateHash(getVars(), hash);
Utils::updateHash(getNeuronVarRefs(), hash);
- Utils::updateHash(getDecayCode(), hash);
- Utils::updateHash(getApplyInputCode(), hash);
+ Utils::updateHash(getSimCode(), hash);
return hash.get_digest();
}
//----------------------------------------------------------------------------
@@ -60,13 +59,12 @@ Init::Init(const Base *snippet, const std::unordered_mapgetNeuronVarRefs());
// Scan code tokens
- m_DecayCodeTokens = Utils::scanCode(getSnippet()->getDecayCode(), "Postsynaptic model decay code");
- m_ApplyInputCodeTokens = Utils::scanCode(getSnippet()->getApplyInputCode(), "Postsynaptic model apply input code");
+ m_SimCodeTokens = Utils::scanCode(getSnippet()->getSimCode(), "Postsynaptic model sim code");
}
//----------------------------------------------------------------------------
bool Init::isRNGRequired() const
{
- return (Utils::isRNGRequired(m_DecayCodeTokens) || Utils::isRNGRequired(m_ApplyInputCodeTokens));
+ return Utils::isRNGRequired(m_SimCodeTokens);
}
//----------------------------------------------------------------------------
bool Init::isVarInitRequired() const
diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc
index 43754d8e2..023aca191 100644
--- a/src/genn/genn/synapseGroup.cc
+++ b/src/genn/genn/synapseGroup.cc
@@ -575,13 +575,8 @@ bool SynapseGroup::canPSBeFused(const NeuronGroup *ng) const
// **NOTE** this is kind of silly as, if it's not referenced in either of
// these code strings, there wouldn't be a lot of point in a PSM EGP existing!
for(const auto &egp : getPSInitialiser().getSnippet()->getExtraGlobalParams()) {
- // If this EGP is referenced in decay code, return false
- if(Utils::isIdentifierReferenced(egp.name, getPSInitialiser().getDecayCodeTokens())) {
- return false;
- }
-
- // If this EGP is referenced in apply input code, return false
- if(Utils::isIdentifierReferenced(egp.name, getPSInitialiser().getApplyInputCodeTokens())) {
+ // If this EGP is referenced in sim code, return false
+ if(Utils::isIdentifierReferenced(egp.name, getPSInitialiser().getSimCodeTokens())) {
return false;
}
}
@@ -590,13 +585,8 @@ bool SynapseGroup::canPSBeFused(const NeuronGroup *ng) const
for(const auto &p : getPSInitialiser().getSnippet()->getParams()) {
// If parameter is dynamic
if(isPSParamDynamic(p.name)) {
- // If this parameter is referenced in decay code, return false
- if(Utils::isIdentifierReferenced(p.name, getPSInitialiser().getDecayCodeTokens())) {
- return false;
- }
-
- // If this parameter is referenced in apply input code, return false
- if(Utils::isIdentifierReferenced(p.name, getPSInitialiser().getApplyInputCodeTokens())) {
+ // If this parameter is referenced in sim code, return false
+ if(Utils::isIdentifierReferenced(p.name, getPSInitialiser().getSimCodeTokens())) {
return false;
}
}
diff --git a/tests/features/test_dynamic_param.py b/tests/features/test_dynamic_param.py
index 58b40007b..1775ee062 100644
--- a/tests/features/test_dynamic_param.py
+++ b/tests/features/test_dynamic_param.py
@@ -28,15 +28,12 @@ def test_dynamic_param(make_model, backend, precision):
postsynaptic_model = create_postsynaptic_model(
"postsynaptic",
- decay_code=
+ sim_code=
"""
+ injectCurrent(inSyn);
psmX = t + psmShift + psmInput;
$(inSyn) = 0;
""",
- apply_input_code=
- """
- $(Isyn) += $(inSyn);
- """,
params=["psmInput"],
var_name_types=[("psmX", "scalar"), ("psmShift", "scalar")])
diff --git a/tests/unit/modelSpec.cc b/tests/unit/modelSpec.cc
index 7dbeb0137..922fdd687 100644
--- a/tests/unit/modelSpec.cc
+++ b/tests/unit/modelSpec.cc
@@ -16,11 +16,10 @@ class AlphaCurr : public PostsynapticModels::Base
public:
DECLARE_SNIPPET(AlphaCurr);
- SET_DECAY_CODE(
- "x = (DT * expDecay * inSyn * init) + (expDecay * x);\n"
- "inSyn*=expDecay;\n");
-
- SET_CURRENT_CONVERTER_CODE("x");
+ SET_SIM_CODE(
+ "injectCurrent(x);\n"
+ "x = (dt * expDecay * inSyn * init) + (expDecay * x);\n"
+ "inSyn *= expDecay;\n");
SET_PARAMS({"tau"});
diff --git a/tests/unit/modelSpecMerged.cc b/tests/unit/modelSpecMerged.cc
index d1de69415..bd0eccd30 100644
--- a/tests/unit/modelSpecMerged.cc
+++ b/tests/unit/modelSpecMerged.cc
@@ -24,12 +24,11 @@ class AlphaCurr : public PostsynapticModels::Base
public:
DECLARE_SNIPPET(AlphaCurr);
- SET_DECAY_CODE(
+ SET_SIM_CODE(
+ "injectCurrent(x);\n"
"x = (dt * expDecay * inSyn * init) + (expDecay * x);\n"
"inSyn *= expDecay;\n");
- SET_CURRENT_CONVERTER_CODE("x");
-
SET_PARAMS({"tau"});
SET_VARS({{"x", "scalar"}});
diff --git a/tests/unit/models.cc b/tests/unit/models.cc
index c84a15ae1..15b5873cb 100644
--- a/tests/unit/models.cc
+++ b/tests/unit/models.cc
@@ -16,12 +16,11 @@ class AlphaCurr : public PostsynapticModels::Base
public:
DECLARE_SNIPPET(AlphaCurr);
- SET_DECAY_CODE(
+ SET_SIM_CODE(
+ "injectCurrent(x);\n"
"x = (dt * expDecay * inSyn * init) + (expDecay * x);\n"
"inSyn *= expDecay;\n");
- SET_CURRENT_CONVERTER_CODE("x");
-
SET_PARAMS({"tau"});
SET_VARS({{"x", "scalar"}});
diff --git a/tests/unit/neuronGroup.cc b/tests/unit/neuronGroup.cc
index d4ab7b840..2609bf881 100644
--- a/tests/unit/neuronGroup.cc
+++ b/tests/unit/neuronGroup.cc
@@ -88,11 +88,10 @@ class AlphaCurr : public PostsynapticModels::Base
public:
DECLARE_SNIPPET(AlphaCurr);
- SET_DECAY_CODE(
+ SET_SIM_CODE(
+ "injectCurrent(x);\n"
"x = (dt * expDecay * inSyn * init) + (expDecay * x);\n"
- "inSyn*=expDecay;\n");
-
- SET_CURRENT_CONVERTER_CODE("x");
+ "inSyn *= expDecay;\n");
SET_PARAMS({"tau"});
diff --git a/tests/unit/postsynapticModels.cc b/tests/unit/postsynapticModels.cc
index 718d49aca..76a31f7b9 100644
--- a/tests/unit/postsynapticModels.cc
+++ b/tests/unit/postsynapticModels.cc
@@ -13,9 +13,9 @@ using namespace GeNN;
class ExpCurrCopy : public PostsynapticModels::Base
{
public:
- SET_DECAY_CODE("inSyn *= expDecay;");
-
- SET_CURRENT_CONVERTER_CODE("init * inSyn");
+ SET_SIM_CODE(
+ "injectCurrent(init * inSyn);\n"
+ "inSyn *= expDecay;\n");
SET_PARAMS({"tau"});