diff --git a/include/genn/genn/postsynapticModels.h b/include/genn/genn/postsynapticModels.h index 888ebd97b..d62c76f83 100644 --- a/include/genn/genn/postsynapticModels.h +++ b/include/genn/genn/postsynapticModels.h @@ -9,9 +9,7 @@ //---------------------------------------------------------------------------- // Macros //---------------------------------------------------------------------------- -#define SET_DECAY_CODE(DECAY_CODE) virtual std::string getDecayCode() const override{ return DECAY_CODE; } -#define SET_CURRENT_CONVERTER_CODE(CURRENT_CONVERTER_CODE) virtual std::string getApplyInputCode() const override{ return "Isyn += " CURRENT_CONVERTER_CODE ";"; } -#define SET_APPLY_INPUT_CODE(APPLY_INPUT_CODE) virtual std::string getApplyInputCode() const override{ return APPLY_INPUT_CODE; } +#define SET_SIM_CODE(SIM_CODE) virtual std::string getSimCode() const override{ return SIM_CODE; } #define SET_NEURON_VAR_REFS(...) virtual VarRefVec getNeuronVarRefs() const override{ return __VA_ARGS__; } //---------------------------------------------------------------------------- @@ -32,8 +30,7 @@ class GENN_EXPORT Base : public Models::Base //! Gets names and types of model variable references virtual VarRefVec getNeuronVarRefs() const{ return {}; } - virtual std::string getDecayCode() const{ return ""; } - virtual std::string getApplyInputCode() const{ return ""; } + virtual std::string getSimCode() const{ return ""; } //---------------------------------------------------------------------------- // Public API @@ -72,8 +69,7 @@ class GENN_EXPORT Init : public Snippet::Init const std::unordered_map &getVarInitialisers() const{ return m_VarInitialisers; } const std::unordered_map &getNeuronVarReferences() const{ return m_NeuronVarReferences; } - const std::vector &getDecayCodeTokens() const{ return m_DecayCodeTokens; } - const std::vector &getApplyInputCodeTokens() const{ return m_ApplyInputCodeTokens; } + const std::vector &getSimCodeTokens() const{ return m_SimCodeTokens; } void finalise(double dt); @@ -81,8 +77,7 @@ class GENN_EXPORT Init : public Snippet::Init //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - std::vector m_DecayCodeTokens; - std::vector m_ApplyInputCodeTokens; + std::vector m_SimCodeTokens; std::unordered_map m_VarInitialisers; std::unordered_map m_NeuronVarReferences; @@ -99,9 +94,9 @@ class ExpCurr : public Base public: DECLARE_SNIPPET(ExpCurr); - SET_DECAY_CODE("inSyn *= expDecay;"); - - SET_CURRENT_CONVERTER_CODE("init * inSyn"); + SET_SIM_CODE( + "injectCurrent(init * inSyn);\n" + "inSyn *= expDecay;\n"); SET_PARAMS({"tau"}); @@ -124,9 +119,9 @@ class ExpCond : public Base public: DECLARE_SNIPPET(ExpCond); - SET_DECAY_CODE("inSyn*=expDecay;"); - - SET_CURRENT_CONVERTER_CODE("inSyn * (E - V)"); + SET_SIM_CODE( + "injectCurrent(inSyn * (E - V));\n" + "inSyn *= expDecay;\n"); SET_PARAMS({"tau", "E"}); @@ -145,6 +140,8 @@ class DeltaCurr : public Base public: DECLARE_SNIPPET(DeltaCurr); - SET_CURRENT_CONVERTER_CODE("inSyn; inSyn = 0"); + SET_SIM_CODE( + "injectCurrent(inSyn);\n" + "inSyn = 0.0;\n"); }; } // namespace GeNN::PostsynapticModels diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index 0754a1c59..96b7f23ed 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -1020,8 +1020,8 @@ def create_neuron_model(class_name, params=None, param_names=None, def create_postsynaptic_model(class_name, params=None, param_names=None, var_name_types=None, neuron_var_refs=None, - derived_params=None, decay_code=None, - apply_input_code=None, + derived_params=None, sim_code=None, + decay_code=None, apply_input_code=None, extra_global_params=None): """This helper function creates a custom PostsynapticModel class. See also: @@ -1042,21 +1042,21 @@ def create_postsynaptic_model(class_name, params=None, param_names=None, derived_params -- list of pairs, where the first member is string with name of the derived parameter and the second should be a functor returned by create_dpf_class + sim_code -- string with the decay code decay_code -- string with the decay code apply_input_code -- string with the apply input code extra_global_params -- list of pairs of strings with names and types of additional parameters """ body = {} - - if decay_code is not None: - body["get_decay_code"] =\ - lambda self: dedent(upgrade_code_string(decay_code, class_name)) - - if apply_input_code is not None: - body["get_apply_input_code"] =\ - lambda self: dedent(upgrade_code_string(apply_input_code, - class_name)) + if decay_code is not None or apply_input_code is not None: + raise RuntimeError("Creating postsynaptic models with seperate " + "'decay_code' and 'apply_code' code strings is no " + "longer supported. Please provide 'sim_code' using " + "the injectCurrent(X) function to provide input.") + if sim_code is not None: + body["get_sim_code"] =\ + lambda self: dedent(upgrade_code_string(sim_code, class_name)) if var_name_types is not None: body["get_vars"] = \ diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc index 370e7ce12..4641fcc47 100644 --- a/pygenn/src/genn.cc +++ b/pygenn/src/genn.cc @@ -173,8 +173,7 @@ class PyPostsynapticModelBase : public PySnippet virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } virtual VarRefVec getNeuronVarRefs() const override { PYBIND11_OVERRIDE_NAME(VarRefVec, Base, "get_neuron_var_refs", getNeuronVarRefs); } - virtual std::string getDecayCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_decay_code", getDecayCode); } - virtual std::string getApplyInputCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_apply_input_code", getApplyInputCode); } + virtual std::string getSimCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_sim_code", getSimCode); } }; //---------------------------------------------------------------------------- @@ -846,7 +845,7 @@ PYBIND11_MODULE(genn, m) .def("get_additional_input_vars", &NeuronModels::Base::getAdditionalInputVars) .def("is_auto_refractory_required", &NeuronModels::Base::isAutoRefractoryRequired); - + //------------------------------------------------------------------------ // genn.PostsynapticModelBase //------------------------------------------------------------------------ @@ -856,9 +855,8 @@ PYBIND11_MODULE(genn, m) .def("get_vars", &PostsynapticModels::Base::getVars) .def("get_neuron_var_refs", &PostsynapticModels::Base::getNeuronVarRefs) - .def("get_decay_code", &PostsynapticModels::Base::getDecayCode) - .def("get_apply_input_code", &PostsynapticModels::Base::getApplyInputCode); - + .def("get_sim_code", &PostsynapticModels::Base::getSimCode); + //------------------------------------------------------------------------ // genn.WeightUpdateModelBase //------------------------------------------------------------------------ diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 619303d4e..9ec32e279 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -120,9 +120,10 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // **TODO** naming convention psmEnv.add(getScalarType(), "inSyn", "linSyn"); - - // Allow synapse group's PS output var to override what Isyn points to - psmEnv.add(getScalarType(), "Isyn", "$(_" + getArchetype().getPostTargetVar() + ")"); + + // Define inject current function + psmEnv.add(Type::ResolvedType::createFunction(Type::Void, {getScalarType()}), + "injectCurrent", "$(_" + getArchetype().getPostTargetVar() + ") += $(0)"); // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( @@ -133,11 +134,8 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env }); // Pretty print code back to environment - Transpiler::ErrorHandler applyInputErrorHandler("Synapse group '" + getArchetype().getName() + "' postsynaptic model apply input code"); - prettyPrintStatements(getArchetype().getPSInitialiser().getApplyInputCodeTokens(), getTypeContext(), varEnv, applyInputErrorHandler); - - Transpiler::ErrorHandler decayErrorHandler("Synapse group '" + getArchetype().getName() + "' postsynaptic model decay code"); - prettyPrintStatements(getArchetype().getPSInitialiser().getDecayCodeTokens(), getTypeContext(), varEnv, decayErrorHandler); + Transpiler::ErrorHandler errorHandler("Synapse group '" + getArchetype().getName() + "' postsynaptic model apply sim code"); + prettyPrintStatements(getArchetype().getPSInitialiser().getSimCodeTokens(), getTypeContext(), varEnv, errorHandler); // Write back linSyn varEnv.printLine("$(_out_post)[" + ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(id)") + "] = linSyn;"); diff --git a/src/genn/genn/neuronGroup.cc b/src/genn/genn/neuronGroup.cc index 15da10fd9..249efddee 100644 --- a/src/genn/genn/neuronGroup.cc +++ b/src/genn/genn/neuronGroup.cc @@ -336,8 +336,7 @@ bool NeuronGroup::isSimRNGRequired() const return std::any_of(getInSyn().cbegin(), getInSyn().cend(), [](const SynapseGroupInternal *sg) { - return (Utils::isRNGRequired(sg->getPSInitialiser().getApplyInputCodeTokens()) || - Utils::isRNGRequired(sg->getPSInitialiser().getDecayCodeTokens())); + return Utils::isRNGRequired(sg->getPSInitialiser().getSimCodeTokens()); }); } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/postsynapticModels.cc b/src/genn/genn/postsynapticModels.cc index fe131e61c..b5ada146e 100644 --- a/src/genn/genn/postsynapticModels.cc +++ b/src/genn/genn/postsynapticModels.cc @@ -22,8 +22,7 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const Snippet::Base::updateHash(hash); Utils::updateHash(getVars(), hash); Utils::updateHash(getNeuronVarRefs(), hash); - Utils::updateHash(getDecayCode(), hash); - Utils::updateHash(getApplyInputCode(), hash); + Utils::updateHash(getSimCode(), hash); return hash.get_digest(); } //---------------------------------------------------------------------------- @@ -60,13 +59,12 @@ Init::Init(const Base *snippet, const std::unordered_mapgetNeuronVarRefs()); // Scan code tokens - m_DecayCodeTokens = Utils::scanCode(getSnippet()->getDecayCode(), "Postsynaptic model decay code"); - m_ApplyInputCodeTokens = Utils::scanCode(getSnippet()->getApplyInputCode(), "Postsynaptic model apply input code"); + m_SimCodeTokens = Utils::scanCode(getSnippet()->getSimCode(), "Postsynaptic model sim code"); } //---------------------------------------------------------------------------- bool Init::isRNGRequired() const { - return (Utils::isRNGRequired(m_DecayCodeTokens) || Utils::isRNGRequired(m_ApplyInputCodeTokens)); + return Utils::isRNGRequired(m_SimCodeTokens); } //---------------------------------------------------------------------------- bool Init::isVarInitRequired() const diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc index 43754d8e2..023aca191 100644 --- a/src/genn/genn/synapseGroup.cc +++ b/src/genn/genn/synapseGroup.cc @@ -575,13 +575,8 @@ bool SynapseGroup::canPSBeFused(const NeuronGroup *ng) const // **NOTE** this is kind of silly as, if it's not referenced in either of // these code strings, there wouldn't be a lot of point in a PSM EGP existing! for(const auto &egp : getPSInitialiser().getSnippet()->getExtraGlobalParams()) { - // If this EGP is referenced in decay code, return false - if(Utils::isIdentifierReferenced(egp.name, getPSInitialiser().getDecayCodeTokens())) { - return false; - } - - // If this EGP is referenced in apply input code, return false - if(Utils::isIdentifierReferenced(egp.name, getPSInitialiser().getApplyInputCodeTokens())) { + // If this EGP is referenced in sim code, return false + if(Utils::isIdentifierReferenced(egp.name, getPSInitialiser().getSimCodeTokens())) { return false; } } @@ -590,13 +585,8 @@ bool SynapseGroup::canPSBeFused(const NeuronGroup *ng) const for(const auto &p : getPSInitialiser().getSnippet()->getParams()) { // If parameter is dynamic if(isPSParamDynamic(p.name)) { - // If this parameter is referenced in decay code, return false - if(Utils::isIdentifierReferenced(p.name, getPSInitialiser().getDecayCodeTokens())) { - return false; - } - - // If this parameter is referenced in apply input code, return false - if(Utils::isIdentifierReferenced(p.name, getPSInitialiser().getApplyInputCodeTokens())) { + // If this parameter is referenced in sim code, return false + if(Utils::isIdentifierReferenced(p.name, getPSInitialiser().getSimCodeTokens())) { return false; } } diff --git a/tests/features/test_dynamic_param.py b/tests/features/test_dynamic_param.py index 58b40007b..1775ee062 100644 --- a/tests/features/test_dynamic_param.py +++ b/tests/features/test_dynamic_param.py @@ -28,15 +28,12 @@ def test_dynamic_param(make_model, backend, precision): postsynaptic_model = create_postsynaptic_model( "postsynaptic", - decay_code= + sim_code= """ + injectCurrent(inSyn); psmX = t + psmShift + psmInput; $(inSyn) = 0; """, - apply_input_code= - """ - $(Isyn) += $(inSyn); - """, params=["psmInput"], var_name_types=[("psmX", "scalar"), ("psmShift", "scalar")]) diff --git a/tests/unit/modelSpec.cc b/tests/unit/modelSpec.cc index 7dbeb0137..922fdd687 100644 --- a/tests/unit/modelSpec.cc +++ b/tests/unit/modelSpec.cc @@ -16,11 +16,10 @@ class AlphaCurr : public PostsynapticModels::Base public: DECLARE_SNIPPET(AlphaCurr); - SET_DECAY_CODE( - "x = (DT * expDecay * inSyn * init) + (expDecay * x);\n" - "inSyn*=expDecay;\n"); - - SET_CURRENT_CONVERTER_CODE("x"); + SET_SIM_CODE( + "injectCurrent(x);\n" + "x = (dt * expDecay * inSyn * init) + (expDecay * x);\n" + "inSyn *= expDecay;\n"); SET_PARAMS({"tau"}); diff --git a/tests/unit/modelSpecMerged.cc b/tests/unit/modelSpecMerged.cc index d1de69415..bd0eccd30 100644 --- a/tests/unit/modelSpecMerged.cc +++ b/tests/unit/modelSpecMerged.cc @@ -24,12 +24,11 @@ class AlphaCurr : public PostsynapticModels::Base public: DECLARE_SNIPPET(AlphaCurr); - SET_DECAY_CODE( + SET_SIM_CODE( + "injectCurrent(x);\n" "x = (dt * expDecay * inSyn * init) + (expDecay * x);\n" "inSyn *= expDecay;\n"); - SET_CURRENT_CONVERTER_CODE("x"); - SET_PARAMS({"tau"}); SET_VARS({{"x", "scalar"}}); diff --git a/tests/unit/models.cc b/tests/unit/models.cc index c84a15ae1..15b5873cb 100644 --- a/tests/unit/models.cc +++ b/tests/unit/models.cc @@ -16,12 +16,11 @@ class AlphaCurr : public PostsynapticModels::Base public: DECLARE_SNIPPET(AlphaCurr); - SET_DECAY_CODE( + SET_SIM_CODE( + "injectCurrent(x);\n" "x = (dt * expDecay * inSyn * init) + (expDecay * x);\n" "inSyn *= expDecay;\n"); - SET_CURRENT_CONVERTER_CODE("x"); - SET_PARAMS({"tau"}); SET_VARS({{"x", "scalar"}}); diff --git a/tests/unit/neuronGroup.cc b/tests/unit/neuronGroup.cc index d4ab7b840..2609bf881 100644 --- a/tests/unit/neuronGroup.cc +++ b/tests/unit/neuronGroup.cc @@ -88,11 +88,10 @@ class AlphaCurr : public PostsynapticModels::Base public: DECLARE_SNIPPET(AlphaCurr); - SET_DECAY_CODE( + SET_SIM_CODE( + "injectCurrent(x);\n" "x = (dt * expDecay * inSyn * init) + (expDecay * x);\n" - "inSyn*=expDecay;\n"); - - SET_CURRENT_CONVERTER_CODE("x"); + "inSyn *= expDecay;\n"); SET_PARAMS({"tau"}); diff --git a/tests/unit/postsynapticModels.cc b/tests/unit/postsynapticModels.cc index 718d49aca..76a31f7b9 100644 --- a/tests/unit/postsynapticModels.cc +++ b/tests/unit/postsynapticModels.cc @@ -13,9 +13,9 @@ using namespace GeNN; class ExpCurrCopy : public PostsynapticModels::Base { public: - SET_DECAY_CODE("inSyn *= expDecay;"); - - SET_CURRENT_CONVERTER_CODE("init * inSyn"); + SET_SIM_CODE( + "injectCurrent(init * inSyn);\n" + "inSyn *= expDecay;\n"); SET_PARAMS({"tau"});