From 0d570431c10eac72dfa635832e64075b6fd6806a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 16 Aug 2023 17:52:51 +0100 Subject: [PATCH 01/60] define new access types based on dimensions --- include/genn/genn/varAccess.h | 82 ++++++++++++++++++++++++++++++----- 1 file changed, 72 insertions(+), 10 deletions(-) diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h index 43b14378cc..2c1cd56022 100644 --- a/include/genn/genn/varAccess.h +++ b/include/genn/genn/varAccess.h @@ -33,6 +33,68 @@ enum class VarAccessDuplication : unsigned int SHARED_NEURON = (1 << 7) //! This variable should be shared between neurons }; +//! Flags defining dimensions this variable has +enum class VarAccessDim : unsigned int +{ + NEURON = (1 << 5), + PRE_NEURON = (1 << 6), + POST_NEURON = (1 << 7), + DELAY = (1 << 8), + BATCH = (1 << 9), +}; + +//! Supported combinations of access mode and dimension for neuron variables +enum class NeuronVarAccess : unsigned int +{ + READ_WRITE = static_cast<unsigned int>(VarAccessMode::READ_WRITE) | static_cast<unsigned int>(VarAccessDim::NEURON) | static_cast<unsigned int>(VarAccessDim::BATCH), + READ_ONLY = static_cast<unsigned int>(VarAccessMode::READ_ONLY) | static_cast<unsigned int>(VarAccessDim::NEURON) | static_cast<unsigned int>(VarAccessDim::BATCH), + READ_ONLY_DUPLICATE = static_cast<unsigned int>(VarAccessMode::READ_ONLY) | static_cast<unsigned int>(VarAccessDim::NEURON), + READ_ONLY_SHARED_NEURON = static_cast<unsigned int>(VarAccessMode::READ_ONLY) | static_cast<unsigned int>(VarAccessDim::BATCH), +}; + +//! Supported combinations of access mode and dimension for synapse variables +/*enum class SynapseVarAccess : unsigned int +{ + // Synaptic variables + READ_WRITE = static_cast<unsigned int>(VarAccessMode::READ_WRITE) | static_cast<unsigned int>(VarAccessDim::PRE_NEURON) | static_cast<unsigned int>(VarAccessDim::POST_NEURON) | static_cast<unsigned int>(VarAccessDim::BATCH), + READ_ONLY = static_cast<unsigned int>(VarAccessMode::READ_ONLY) | static_cast<unsigned int>(VarAccessDim::PRE_NEURON) | static_cast<unsigned int>(VarAccessDim::POST_NEURON), + READ_ONLY_DUPLICATE = static_cast<unsigned int>(VarAccessMode::READ_ONLY) | static_cast<unsigned int>(VarAccessDim::PRE_NEURON) | static_cast<unsigned int>(VarAccessDim::POST_NEURON) | static_cast<unsigned int>(VarAccessDim::BATCH), + + // Presynaptic variables + READ_WRITE_PRE = static_cast<unsigned int>(VarAccessMode::READ_WRITE) | static_cast<unsigned int>(VarAccessDim::PRE_NEURON) | static_cast<unsigned int>(VarAccessDim::BATCH), + READ_ONLY_PRE = static_cast<unsigned int>(VarAccessMode::READ_ONLY) | static_cast<unsigned int>(VarAccessDim::PRE_NEURON), + READ_ONLY_PRE_DUPLICATE = static_cast<unsigned int>(VarAccessMode::READ_ONLY) | static_cast<unsigned int>(VarAccessDim::PRE_NEURON) | static_cast<unsigned int>(VarAccessDim::BATCH), + + // Postsynaptic variables + READ_WRITE_POST = static_cast<unsigned int>(VarAccessMode::READ_WRITE) | static_cast<unsigned int>(VarAccessDim::POST_NEURON) | static_cast<unsigned int>(VarAccessDim::BATCH), + READ_ONLY_POST = static_cast<unsigned int>(VarAccessMode::READ_ONLY) | static_cast<unsigned int>(VarAccessDim::POST_NEURON), + READ_ONLY_POST_DUPLICATE = static_cast<unsigned int>(VarAccessMode::READ_ONLY) | static_cast<unsigned int>(VarAccessDim::POST_NEURON) | static_cast<unsigned int>(VarAccessDim::BATCH), +}; + +enum class CustomUpdateVarAccess : unsigned int +{ + // Variables with matching shape + READ_WRITE, + READ_ONLY, + + // Variables shared across batches + READ_WRITE_SHARED, + READ_ONLY_SHARED, + + + READ_WRITE_PRE, + + // Reduction variables + REDUCE_BATCH_SUM, + REDUCE_BATCH_MAX, + REDUCE_NEURON_SUM, + REDUCE_NEURON_MAX, + REDUCE_PRE_NEURON_SUM, + REDUCE_PRE_NEURON_MAX, + REDUCE_POST_NEURON_SUM, + REDUCE_POST_NEURON_MAX, +}*/ + +//! 
Supported combinations of VarAccessMode and VarAccessDuplication enum class VarAccess : unsigned int { @@ -49,19 +111,19 @@ enum class VarAccess : unsigned int //---------------------------------------------------------------------------- // Operators //---------------------------------------------------------------------------- -inline bool operator & (VarAccess type, VarAccessMode mode) +inline bool operator & (unsigned int type, VarAccessMode mode) { - return (static_cast<unsigned int>(type) & static_cast<unsigned int>(mode)) != 0; + return (type & static_cast<unsigned int>(mode)) != 0; } -inline bool operator & (VarAccess type, VarAccessDuplication duplication) +inline bool operator & (unsigned int type, VarAccessDuplication duplication) { - return (static_cast<unsigned int>(type) & static_cast<unsigned int>(duplication)) != 0; + return (type & static_cast<unsigned int>(duplication)) != 0; } -inline bool operator & (VarAccess type, VarAccessModeAttribute modeAttribute) +inline bool operator & (unsigned int type, VarAccessModeAttribute modeAttribute) { - return (static_cast<unsigned int>(type) & static_cast<unsigned int>(modeAttribute)) != 0; + return (type & static_cast<unsigned int>(modeAttribute)) != 0; } inline bool operator & (VarAccessMode mode, VarAccessModeAttribute modeAttribute) @@ -78,13 +140,13 @@ inline bool operator & (VarAccessMode a, VarAccessMode b) //---------------------------------------------------------------------------- // Helpers //---------------------------------------------------------------------------- -inline VarAccessMode getVarAccessMode(VarAccess type) +inline VarAccessMode getVarAccessMode(unsigned int type) { - return static_cast<VarAccessMode>(static_cast<unsigned int>(type) & 0x1F); + return static_cast<VarAccessMode>(type & 0x1F); } -inline VarAccessDuplication getVarAccessDuplication(VarAccess type) +inline VarAccessDuplication getVarAccessDuplication(unsigned int type) { - return static_cast<VarAccessDuplication>(static_cast<unsigned int>(type) & ~0x1F); + return static_cast<VarAccessDuplication>(type & ~0x1F); } } // namespace GeNN From 6be6abd6fdd14488814176c70c382e8f0d7eccf3 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 16 Aug 2023 17:54:38 +0100 Subject: [PATCH 02/60] started refactoring variable access -Made ``Models::Base::Var::access`` a ``std::optional`` so ``VarAccess`` enum is only for API purposes and started inserting default access types at callsites --- .../genn/genn/code_generator/backendBase.h | 7 +++-- .../genn/genn/code_generator/environment.h | 23 +++++++------- include/genn/genn/customUpdate.h | 11 ++++--- include/genn/genn/models.h | 26 +++++++++++++--- .../backends/single_threaded_cpu/backend.cc | 6 ++-- src/genn/genn/code_generator/backendBase.cc | 6 ++-- .../genn/code_generator/generateRunner.cc | 6 ++-- .../synapseUpdateGroupMerged.cc | 8 ++--- src/genn/genn/currentSourceModels.cc | 6 +++- src/genn/genn/customConnectivityUpdate.cc | 12 ++++++-- .../genn/customConnectivityUpdateModels.cc | 30 ++++++++++++++----- src/genn/genn/customUpdate.cc | 18 +++++++---- src/genn/genn/models.cc | 10 +++++-- src/genn/genn/neuronModels.cc | 6 +++- src/genn/genn/postsynapticModels.cc | 6 +++- src/genn/genn/weightUpdateModels.cc | 24 ++++++++++++--- 16 files changed, 147 insertions(+), 58 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 07ac42c980..215afe8e1d 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -568,11 +568,12 @@ class GENN_EXPORT BackendBase const auto *cm = cg.getArchetype().getCustomUpdateModel(); for (const auto &v : cm->getVars()) { // If variable is a reduction target, define variable 
initialised to correct initial value for reduction - if (v.access & VarAccessModeAttribute::REDUCE) { + const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + if (varAccess & VarAccessModeAttribute::REDUCE) { const auto resolvedType = v.type.resolve(cg.getTypeContext()); os << resolvedType.getName() << " _lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(v.access), resolvedType) << ";" << std::endl; - reductionTargets.push_back({v.name, resolvedType, getVarAccessMode(v.access), - cg.getVarIndex(getVarAccessDuplication(v.access), idx)}); + reductionTargets.push_back({v.name, resolvedType, getVarAccessMode(varAccess), + cg.getVarIndex(getVarAccessDuplication(varAccess), idx)}); } } diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 3b380ed126..e5b5beb6b9 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -405,7 +405,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &(GroupInternal::*)(void) const; - using GetVarIndexFn = std::function; + using GetVarIndexFn = std::function, const std::string&)>; template using GetVarRefIndexFn = std::function; @@ -653,7 +653,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase(arrayPrefix, [&indexSuffix](VarAccess, const std::string &) { return indexSuffix; }, + addVars(arrayPrefix, [&indexSuffix](std::optional, const std::string &) { return indexSuffix; }, fieldSuffix, readOnly); } @@ -682,7 +682,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase(arrayPrefix, [&indexSuffix](VarAccess a, auto &) { return indexSuffix; }, + addVarRefs(arrayPrefix, [&indexSuffix](std::optional, auto &) { return indexSuffix; }, fieldSuffix); } @@ -722,22 +722,23 @@ class VarCachePolicy //------------------------------------------------------------------------ bool shouldAlwaysCopy(G&, const Models::Base::Var &var) const { - if(m_ShouldAlwaysCopy) { - return m_ShouldAlwaysCopy(var.name, getVarAccessDuplication(var.access)); - } - else { - return false; - } + // **TODO** default from InitModel class + const unsigned int varAccess = var.access.value_or(static_cast(VarAccess::READ_WRITE)); + return m_ShouldAlwaysCopy(var.name, getVarAccessDuplication(varAccess)); } std::string getReadIndex(G&, const Models::Base::Var &var) const { - return m_GetReadIndex(var.name, getVarAccessDuplication(var.access)); + // **TODO** default from InitModel class + const unsigned int varAccess = var.access.value_or(static_cast(VarAccess::READ_WRITE)); + return m_GetReadIndex(var.name, getVarAccessDuplication(varAccess)); } std::string getWriteIndex(G&, const Models::Base::Var &var) const { - return m_GetWriteIndex(var.name, getVarAccessDuplication(var.access)); + // **TODO** default from InitModel class + const unsigned int varAccess = var.access.value_or(static_cast(VarAccess::READ_WRITE)); + return m_GetWriteIndex(var.name, getVarAccessDuplication(varAccess)); } std::string getTargetName(const GroupInternal &g, const Models::Base::Var &var) const diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index bf907a75b9..242789c5de 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -93,7 +93,8 @@ class GENN_EXPORT CustomUpdateBase if(std::any_of(vars.cbegin(), vars.cend(), [duplication](const Models::Base::Var &v) { - return (v.access & VarAccessModeAttribute::REDUCE) && (v.access & 
duplication); + const unsigned int access = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + return (access & VarAccessModeAttribute::REDUCE) && (access & duplication); })) { return true; @@ -101,9 +102,10 @@ class GENN_EXPORT CustomUpdateBase // Loop through all variable references for(const auto &modelVarRef : getCustomUpdateModel()->getVarRefs()) { - const auto &varRef = varRefs.at(modelVarRef.name); // If custom update model reduces into this variable reference and the variable it targets has correct duplication flag - if ((modelVarRef.access & VarAccessModeAttribute::REDUCE) & (varRef.getVar().access & duplication)) { + const auto &varRef = varRefs.at(modelVarRef.name); + const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + if ((modelVarRef.access & VarAccessModeAttribute::REDUCE) & (varAccess & duplication)) { return true; } } @@ -130,7 +132,8 @@ class GENN_EXPORT CustomUpdateBase // If custom update is batched, check that any variable references to shared variables are read-only // **NOTE** if custom update isn't batched, it's totally fine to write to shared variables - if(m_Batched && (varRef.getVar().access & VarAccessDuplication::SHARED) + const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + if(m_Batched && (varAccess & VarAccessDuplication::SHARED) && (modelVarRef.access == VarAccessMode::READ_WRITE)) { throw std::runtime_error("Variable references to SHARED variables in batched custom updates cannot be read-write."); diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index c0cfe0f849..1107df12ff 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -54,10 +54,26 @@ class GENN_EXPORT Base : public Snippet::Base if not specified, this results in a -Wmissing-field-initializers warning on GCC and Clang*/ struct GENN_EXPORT Var { - Var(const std::string &n, const Type::ResolvedType &t, VarAccess a = VarAccess::READ_WRITE) : name(n), type(t), access(a) + Var(const std::string &n, const Type::ResolvedType &t) + : name(n), type(t) {} - Var(const std::string &n, const std::string &t, VarAccess a = VarAccess::READ_WRITE) : name(n), type(t), access(a) + Var(const std::string &n, const std::string &t) + : name(n), type(t) {} + + Var(const std::string &n, const Type::ResolvedType &t, VarAccess a) + : name(n), type(t), access(static_cast(a)) + {} + Var(const std::string &n, const std::string &t, VarAccess a) + : name(n), type(t), access(static_cast(a)) + {} + + /*Var(const std::string &n, const Type::ResolvedType &t, NeuronVarAccess a) + : name(n), type(t), access(static_cast(a)) + {} + Var(const std::string &n, const std::string &t, NeuronVarAccess a) + : name(n), type(t), access(static_cast(a)) + {}*/ bool operator == (const Var &other) const { @@ -66,7 +82,7 @@ class GENN_EXPORT Base : public Snippet::Base std::string name; Type::UnresolvedType type; - VarAccess access; + std::optional access; }; struct GENN_EXPORT VarRef @@ -389,7 +405,9 @@ void checkVarReferences(const std::unordered_map &varRefs, const } // Check that no reduction targets reference duplicated variables - if((varRef.getVar().access & VarAccessDuplication::DUPLICATE) + // **TODO** default from InitModel class + const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + if((varAccess & VarAccessDuplication::DUPLICATE) && (modelVarRef.access & VarAccessModeAttribute::REDUCE)) { throw std::runtime_error("Reduction target variable 
reference must be to SHARED or SHARED_NEURON variables."); diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 1beec92efc..c3e7eb9e0e 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -2016,8 +2016,9 @@ void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateG genWriteBackReductions(env, cg, idxName, [&cg](const Models::VarReference &varRef, const std::string &index) { + const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, - getVarAccessDuplication(varRef.getVar().access), + getVarAccessDuplication(varAccess), index); }); } @@ -2027,7 +2028,8 @@ void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateW genWriteBackReductions(env, cg, idxName, [&cg](const Models::WUVarReference &varRef, const std::string &index) { - return cg.getVarRefIndex(getVarAccessDuplication(varRef.getVar().access), + const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + return cg.getVarRefIndex(getVarAccessDuplication(varAccess), index); }); } diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 08bdc72a19..e3a27dba43 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -702,8 +702,9 @@ std::vector BackendBase::genInitReductionTargets(C return genInitReductionTargets(os, cg, idx, [&cg](const Models::VarReference &varRef, const std::string &index) { + const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, - getVarAccessDuplication(varRef.getVar().access), + getVarAccessDuplication(varAccess), index); }); } @@ -713,7 +714,8 @@ std::vector BackendBase::genInitReductionTargets(C return genInitReductionTargets(os, cg, idx, [&cg](const Models::WUVarReference &varRef, const std::string &index) { - return cg.getVarRefIndex(getVarAccessDuplication(varRef.getVar().access), + const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + return cg.getVarRefIndex(getVarAccessDuplication(varAccess), index); }); } diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index ec7563f312..f4bb1a15af 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -27,17 +27,17 @@ using namespace GeNN::CodeGenerator; //-------------------------------------------------------------------------- namespace { -unsigned int getNumVarCopies(VarAccess varAccess, unsigned int batchSize, bool batched = true) +unsigned int getNumVarCopies(unsigned int varAccess, unsigned int batchSize, bool batched = true) { return ((varAccess & VarAccessDuplication::SHARED) || !batched) ? 1 : batchSize; } //-------------------------------------------------------------------------- -unsigned int getNumVarElements(VarAccess varAccess, unsigned int numNeurons) +unsigned int getNumVarElements(unsigned int varAccess, unsigned int numNeurons) { return (varAccess & VarAccessDuplication::SHARED_NEURON) ? 
1 : numNeurons; } //-------------------------------------------------------------------------- -unsigned int getVarSize(VarAccess varAccess, unsigned int numElements, unsigned int batchSize, +unsigned int getVarSize(unsigned int varAccess, unsigned int numElements, unsigned int batchSize, unsigned int delaySlots = 1, bool batched = true) { return getNumVarCopies(varAccess, batchSize, batched) * getNumVarElements(varAccess, numElements) * delaySlots; diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index a1577e0096..264285061a 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -30,13 +30,13 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa // Substitute names of pre and postsynaptic weight update variable synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](VarAccess a, const std::string&) + [&sg, batchSize](unsigned int a, const std::string&) { return sg.getPreWUVarIndex(batchSize, getVarAccessDuplication(a), "$(id_pre)"); }, "", true); synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](VarAccess a, const std::string&) + [&sg, batchSize](unsigned int a, const std::string&) { return sg.getPostWUVarIndex(batchSize, getVarAccessDuplication(a), "$(id_post)"); }, "", true); @@ -78,7 +78,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa if (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::INDIVIDUAL) { synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](VarAccess a, const std::string&) + [&sg, batchSize](unsigned int a, const std::string&) { return sg.getSynVarIndex(batchSize, getVarAccessDuplication(a), "$(id_syn)"); }); @@ -121,7 +121,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](VarAccess a, const std::string&) + [&sg, batchSize](unsigned int a, const std::string&) { return sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), "$(id_kernel)"); }); diff --git a/src/genn/genn/currentSourceModels.cc b/src/genn/genn/currentSourceModels.cc index d45b772c57..8f650a5db4 100644 --- a/src/genn/genn/currentSourceModels.cc +++ b/src/genn/genn/currentSourceModels.cc @@ -35,7 +35,11 @@ void Base::validate(const std::unordered_map ¶mValues, // If any variables have a reduction access mode, give an error const auto vars = getVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return (v.access & VarAccessModeAttribute::REDUCE); })) + [](const Models::Base::Var &v) + { + const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + return (varAccess & VarAccessModeAttribute::REDUCE); + })) { throw std::runtime_error("Current source models cannot include variables with REDUCE access modes - they are only supported by custom update models"); } diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index f82dfe68a5..d55c598678 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -178,14 +178,22 @@ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) if (batchSize > 1) { // If any referenced presynaptic variables aren't shared, give error if (std::any_of(getPreVarReferences().cbegin(), getPreVarReferences().cend(), - 
[](const auto &v) { return (getVarAccessDuplication(v.second.getVar().access) != VarAccessDuplication::SHARED); })) + [](const auto &v) + { + const unsigned int varAccess = v.second.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + return (getVarAccessDuplication(varAccess) != VarAccessDuplication::SHARED); + })) { throw std::runtime_error("Presynaptic variables referenced by CustomConnectivityUpdate must be SHARED across batches"); } // If any referenced presynaptic variables aren't shared, give error if (std::any_of(getPostVarReferences().cbegin(), getPostVarReferences().cend(), - [](const auto &v) { return (getVarAccessDuplication(v.second.getVar().access) != VarAccessDuplication::SHARED); })) + [](const auto &v) + { + const unsigned int varAccess = v.second.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + return (getVarAccessDuplication(varAccess) != VarAccessDuplication::SHARED); + })) { throw std::runtime_error("Postsynaptic variables referenced by CustomConnectivityUpdate must be SHARED across batches"); } diff --git a/src/genn/genn/customConnectivityUpdateModels.cc b/src/genn/genn/customConnectivityUpdateModels.cc index 51292a4584..48ab5b5fb4 100644 --- a/src/genn/genn/customConnectivityUpdateModels.cc +++ b/src/genn/genn/customConnectivityUpdateModels.cc @@ -60,12 +60,24 @@ void Base::validate(const std::unordered_map ¶mValues, // If any variables have a reduction access mode, give an error // **YUCK** copy-paste from WUM - could go in helper/Models::Base - if (std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v) { return (v.access & VarAccessModeAttribute::REDUCE); }) - || std::any_of(preVars.cbegin(), preVars.cend(), - [](const Models::Base::Var &v) { return (v.access & VarAccessModeAttribute::REDUCE); }) - || std::any_of(postVars.cbegin(), postVars.cend(), - [](const Models::Base::Var &v) { return (v.access & VarAccessModeAttribute::REDUCE); })) + if(std::any_of(vars.cbegin(), vars.cend(), + [](const Models::Base::Var &v) + { + const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + return (varAccess & VarAccessModeAttribute::REDUCE); + }) + || std::any_of(preVars.cbegin(), preVars.cend(), + [](const Models::Base::Var &v) + { + const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + return (varAccess & VarAccessModeAttribute::REDUCE); + }) + || std::any_of(postVars.cbegin(), postVars.cend(), + [](const Models::Base::Var &v) + { + const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + return (varAccess & VarAccessModeAttribute::REDUCE); + })) { throw std::runtime_error("Custom connectivity update models cannot include variables with REDUCE access modes - they are only supported by custom update models"); } @@ -73,7 +85,11 @@ void Base::validate(const std::unordered_map ¶mValues, // If any variables have shared neuron duplication mode, give an error // **YUCK** copy-paste from WUM - could go in helper/Models::Base if (std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v) { return (v.access & VarAccessDuplication::SHARED_NEURON); })) + [](const Models::Base::Var &v) + { + const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + return (varAccess & VarAccessDuplication::SHARED_NEURON); + })) { throw std::runtime_error("Custom connectivity update models cannot include variables with SHARED_NEURON access modes - they are only supported on pre, postsynaptic or neuron variables"); } diff --git 
a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index aa91135862..0d15cecca7 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -139,12 +139,14 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro m_PerNeuron = std::any_of(m_VarReferences.cbegin(), m_VarReferences.cend(), [](const auto& v) { - return !(v.second.getVar().access & VarAccessDuplication::SHARED_NEURON); + const unsigned int varAccess = v.second.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + return !(varAccess & VarAccessDuplication::SHARED_NEURON); }); m_PerNeuron |= std::any_of(modelVars.cbegin(), modelVars.cend(), [](const Models::Base::Var& v) { - return !(v.access & VarAccessDuplication::SHARED_NEURON); + const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + return !(varAccess & VarAccessDuplication::SHARED_NEURON); }); // Loop through all variable references @@ -153,7 +155,8 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro // If custom update is per-neuron, check that any variable references to SHARED_NEURON variables are read-only // **NOTE** if custom update isn't per-neuron, it's totally fine to write to SHARED_NEURON variables - if(m_PerNeuron && (varRef.getVar().access & VarAccessDuplication::SHARED_NEURON) + const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + if(m_PerNeuron && (varAccess & VarAccessDuplication::SHARED_NEURON) && (modelVarRef.access == VarAccessMode::READ_WRITE)) { throw std::runtime_error("Variable references to SHARED_NEURON variables in per-neuron custom updates cannot be read-write."); @@ -219,7 +222,8 @@ boost::uuids::detail::sha1::digest_type CustomUpdate::getHashDigest() const Utils::updateHash((v.second.getDelayNeuronGroup() == nullptr), hash); // Update hash with duplication mode of target variable as this effects indexing code - Utils::updateHash(getVarAccessDuplication(v.second.getVar().access), hash); + const unsigned int varAccess = v.second.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + Utils::updateHash(getVarAccessDuplication(varAccess), hash); } return hash.get_digest(); } @@ -271,7 +275,8 @@ CustomUpdateWU::CustomUpdateWU(const std::string &name, const std::string &updat if (std::any_of(vars.cbegin(), vars.cend(), [](const Models::Base::Var &v) { - return (v.access & VarAccessDuplication::SHARED_NEURON); + const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + return (varAccess & VarAccessDuplication::SHARED_NEURON); })) { throw std::runtime_error("Custom weight updates cannot use models with SHARED_NEURON variables."); @@ -341,7 +346,8 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getHashDigest() const Utils::updateHash((v.second.getTransposeSynapseGroup() == nullptr), hash); // Update hash with duplication mode of target variable as this effects indexing code - Utils::updateHash(getVarAccessDuplication(v.second.getVar().access), hash); + const unsigned int varAccess = v.second.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + Utils::updateHash(getVarAccessDuplication(varAccess), hash); } return hash.get_digest(); diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 916a234267..c5d8a5ca77 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -75,7 +75,8 @@ std::string VarReference::getTargetName() const 
//---------------------------------------------------------------------------- bool VarReference::isDuplicated() const { - if(getVar().access & VarAccessDuplication::SHARED) { + const unsigned int varAccess = getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + if(varAccess & VarAccessDuplication::SHARED) { return false; } else { @@ -175,7 +176,8 @@ std::string WUVarReference::getTargetName() const //---------------------------------------------------------------------------- bool WUVarReference::isDuplicated() const { - if(getVar().access & VarAccessDuplication::SHARED) { + const unsigned int varAccess = getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + if(varAccess & VarAccessDuplication::SHARED) { return false; } else { @@ -330,7 +332,9 @@ WUVarReference::WUVarReference(size_t varIndex, const Models::Base::VarVec &varV } // Check duplicatedness of variables - if((getVar().access & VarAccessDuplication::DUPLICATE) != (getTransposeVar().access & VarAccessDuplication::DUPLICATE)) { + const unsigned int varAccess = getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + const unsigned int transposeVarAccess = getTransposeVar().access.value_or(static_cast(VarAccess::READ_WRITE)); + if((varAccess & VarAccessDuplication::DUPLICATE) != (transposeVarAccess & VarAccessDuplication::DUPLICATE)) { throw std::runtime_error("Transpose updates can only be performed on similarly batched variables"); } } diff --git a/src/genn/genn/neuronModels.cc b/src/genn/genn/neuronModels.cc index 8fe57b831a..152dc43881 100644 --- a/src/genn/genn/neuronModels.cc +++ b/src/genn/genn/neuronModels.cc @@ -50,7 +50,11 @@ void Base::validate(const std::unordered_map ¶mValues, // If any variables have a reduction access mode, give an error const auto vars = getVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return (v.access & VarAccessModeAttribute::REDUCE); })) + [](const Models::Base::Var &v) + { + const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + return (varAccess & VarAccessModeAttribute::REDUCE); + })) { throw std::runtime_error("Neuron models cannot include variables with REDUCE access modes - they are only supported by custom update models"); } diff --git a/src/genn/genn/postsynapticModels.cc b/src/genn/genn/postsynapticModels.cc index 27f04e789e..bf23c50a1a 100644 --- a/src/genn/genn/postsynapticModels.cc +++ b/src/genn/genn/postsynapticModels.cc @@ -35,7 +35,11 @@ void Base::validate(const std::unordered_map ¶mValues, const auto vars = getVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return (v.access & VarAccessModeAttribute::REDUCE); })) + [](const Models::Base::Var &v) + { + const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + return (varAccess & VarAccessModeAttribute::REDUCE); + })) { throw std::runtime_error("Postsynaptic models cannot include variables with REDUCE access modes - they are only supported by custom update models"); } diff --git a/src/genn/genn/weightUpdateModels.cc b/src/genn/genn/weightUpdateModels.cc index c6b399e548..9a2a1b8b24 100644 --- a/src/genn/genn/weightUpdateModels.cc +++ b/src/genn/genn/weightUpdateModels.cc @@ -85,11 +85,23 @@ void Base::validate(const std::unordered_map ¶mValues, const auto preVars = getPreVars(); const auto postVars = getPostVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return (v.access & VarAccessModeAttribute::REDUCE); }) + [](const 
Models::Base::Var &v) + { + const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + return (varAccess & VarAccessModeAttribute::REDUCE); + }) || std::any_of(preVars.cbegin(), preVars.cend(), - [](const Models::Base::Var &v){ return (v.access & VarAccessModeAttribute::REDUCE); }) + [](const Models::Base::Var &v) + { + const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + return (varAccess & VarAccessModeAttribute::REDUCE); + }) || std::any_of(postVars.cbegin(), postVars.cend(), - [](const Models::Base::Var &v){ return (v.access & VarAccessModeAttribute::REDUCE); })) + [](const Models::Base::Var &v) + { + const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + return (varAccess & VarAccessModeAttribute::REDUCE); + })) { throw std::runtime_error("Weight update models cannot include variables with REDUCE access modes - they are only supported by custom update models"); } @@ -102,7 +114,11 @@ void Base::validate(const std::unordered_map ¶mValues, // If any variables have shared neuron duplication mode, give an error if (std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v) { return (v.access & VarAccessDuplication::SHARED_NEURON); })) + [](const Models::Base::Var &v) + { + const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + return (varAccess & VarAccessDuplication::SHARED_NEURON); + })) { throw std::runtime_error("Weight update models cannot include variables with SHARED_NEURON access modes - they are only supported on pre, postsynaptic or neuron variables"); } From 623719b33da41c3668882ccdc7f66000b939dce1 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 17 Aug 2023 10:09:24 +0100 Subject: [PATCH 03/60] added accessors for inserting default --- include/genn/backends/cuda/backend.h | 4 +- .../backends/single_threaded_cpu/backend.h | 5 ++- .../genn/genn/code_generator/backendBase.h | 4 +- .../customConnectivityUpdateGroupMerged.h | 2 +- .../code_generator/customUpdateGroupMerged.h | 4 +- .../genn/genn/code_generator/environment.h | 26 ++++++------- include/genn/genn/customUpdate.h | 10 ++--- include/genn/genn/models.h | 23 +++++++++++- .../backends/single_threaded_cpu/backend.cc | 6 +-- src/genn/genn/code_generator/backendBase.cc | 6 +-- .../customConnectivityUpdateGroupMerged.cc | 4 +- .../code_generator/customUpdateGroupMerged.cc | 6 +-- .../genn/code_generator/generateRunner.cc | 37 ++++++++++++------- .../genn/code_generator/initGroupMerged.cc | 9 +++-- .../code_generator/neuronUpdateGroupMerged.cc | 14 ++++--- src/genn/genn/currentSourceModels.cc | 3 +- src/genn/genn/customConnectivityUpdate.cc | 6 +-- .../genn/customConnectivityUpdateModels.cc | 21 ++--------- src/genn/genn/customUpdate.cc | 18 +++------ src/genn/genn/models.cc | 12 +++--- src/genn/genn/neuronModels.cc | 6 +-- src/genn/genn/postsynapticModels.cc | 6 +-- src/genn/genn/weightUpdateModels.cc | 21 ++--------- 23 files changed, 118 insertions(+), 135 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index b0e1470264..b74ad8e9d0 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -383,7 +383,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT const auto *cm = cg.getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { // If variable is reduction target - if(v.access & VarAccessModeAttribute::REDUCE) { + if(v.getAccessMode() & 
VarAccessModeAttribute::REDUCE) { // Add pointer field const auto resolvedType = v.type.resolve(cg.getTypeContext()); groupEnv.addField(resolvedType.createPointer(), "_" + v.name, v.name, @@ -394,7 +394,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT // Add NCCL reduction groupEnv.print("CHECK_NCCL_ERRORS(ncclAllReduce($(_" + v.name + "), $(_" + v.name + "), $(_size)"); - groupEnv.printLine(", " + getNCCLType(resolvedType) + ", " + getNCCLReductionType(getVarAccessMode(v.access)) + ", ncclCommunicator, 0));"); + groupEnv.printLine(", " + getNCCLType(resolvedType) + ", " + getNCCLReductionType(v.getAccessMode()) + ", ncclCommunicator, 0));"); } } diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 1ec0f238af..0d5010d6ec 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -252,9 +252,10 @@ class BACKEND_EXPORT Backend : public BackendBase const auto *cm = cg.getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { // If variable is a reduction target, copy value from register straight back into global memory - if(v.access & VarAccessModeAttribute::REDUCE) { + const unsigned int varAccess = v.getAccess(VarAccess::READ_WRITE); + if(varAccess & VarAccessModeAttribute::REDUCE) { const std::string idx = env.getName(idxName); - env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(getVarAccessDuplication(v.access), idx) << "] = " << env[v.name] << ";" << std::endl; + env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(getVarAccessDuplication(varAccess), idx) << "] = " << env[v.name] << ";" << std::endl; } } diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 215afe8e1d..dc5354b13a 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -568,10 +568,10 @@ class GENN_EXPORT BackendBase const auto *cm = cg.getArchetype().getCustomUpdateModel(); for (const auto &v : cm->getVars()) { // If variable is a reduction target, define variable initialised to correct initial value for reduction - const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + const unsigned int varAccess = v.getAccess(VarAccess::READ_WRITE); if (varAccess & VarAccessModeAttribute::REDUCE) { const auto resolvedType = v.type.resolve(cg.getTypeContext()); - os << resolvedType.getName() << " _lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(v.access), resolvedType) << ";" << std::endl; + os << resolvedType.getName() << " _lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(varAccess), resolvedType) << ";" << std::endl; reductionTargets.push_back({v.name, resolvedType, getVarAccessMode(varAccess), cg.getVarIndex(getVarAccessDuplication(varAccess), idx)}); } diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index 36a10362ee..616bbd3e58 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -129,7 +129,7 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public GroupMerged // Loop through variables and add pointers if they are reduction targets const auto *cm = this->getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { - 
if(v.access & VarAccessModeAttribute::REDUCE) { + if(v.getAccessMode() & VarAccessModeAttribute::REDUCE) { const auto fieldType = v.type.resolve(this->getTypeContext()).createPointer(); env.addField(fieldType, v.name, v.name, [&backend, v](const auto &g, size_t) @@ -153,7 +153,7 @@ class CustomUpdateHostReductionGroupMergedBase : public GroupMerged // Loop through variable references and add pointers if they are reduction targets for(const auto &v : cm->getVarRefs()) { - if(v.access & VarAccessModeAttribute::REDUCE) { + if(v.getAccessMode() & VarAccessModeAttribute::REDUCE) { const auto fieldType = v.type.resolve(this->getTypeContext()).createPointer(); env.addField(fieldType, v.name, v.name, [&backend, v](const auto &g, size_t) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index e5b5beb6b9..44e2c77807 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -405,7 +405,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &(GroupInternal::*)(void) const; - using GetVarIndexFn = std::function, const std::string&)>; + using GetVarIndexFn = std::function; template using GetVarRefIndexFn = std::function; @@ -635,17 +635,18 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getArchetype()); for(const auto &v : archetypeAdaptor.getDefs()) { const auto resolvedType = v.type.resolve(this->getGroup().getTypeContext()); - const auto qualifiedType = (readOnly || (getVarAccessMode(v.access) & VarAccessModeAttribute::READ_ONLY)) ? resolvedType.addConst() : resolvedType; + const auto qualifiedType = (readOnly || (v.getAccessMode() & VarAccessModeAttribute::READ_ONLY)) ? resolvedType.addConst() : resolvedType; addField(qualifiedType, v.name, resolvedType.createPointer(), v.name + fieldSuffix, [arrayPrefix, v](const auto &g, size_t) { return arrayPrefix + v.name + A(g).getNameSuffix(); }, - getIndexFn(v.access, v.name)); + getIndexFn(v.getAccess(VarAccess::READ_WRITE), v.name)); } } @@ -653,7 +654,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase(arrayPrefix, [&indexSuffix](std::optional, const std::string &) { return indexSuffix; }, + addVars(arrayPrefix, [&indexSuffix](unsigned int, const std::string &) { return indexSuffix; }, fieldSuffix, readOnly); } @@ -723,22 +724,19 @@ class VarCachePolicy bool shouldAlwaysCopy(G&, const Models::Base::Var &var) const { // **TODO** default from InitModel class - const unsigned int varAccess = var.access.value_or(static_cast(VarAccess::READ_WRITE)); - return m_ShouldAlwaysCopy(var.name, getVarAccessDuplication(varAccess)); + return m_ShouldAlwaysCopy(var.name, getVarAccessDuplication(var.getAccess(VarAccess::READ_WRITE))); } std::string getReadIndex(G&, const Models::Base::Var &var) const { // **TODO** default from InitModel class - const unsigned int varAccess = var.access.value_or(static_cast(VarAccess::READ_WRITE)); - return m_GetReadIndex(var.name, getVarAccessDuplication(varAccess)); + return m_GetReadIndex(var.name, getVarAccessDuplication(var.getAccess(VarAccess::READ_WRITE))); } std::string getWriteIndex(G&, const Models::Base::Var &var) const { // **TODO** default from InitModel class - const unsigned int varAccess = var.access.value_or(static_cast(VarAccess::READ_WRITE)); - return m_GetWriteIndex(var.name, getVarAccessDuplication(varAccess)); + return m_GetWriteIndex(var.name, getVarAccessDuplication(var.getAccess(VarAccess::READ_WRITE))); } 
std::string getTargetName(const GroupInternal &g, const Models::Base::Var &var) const @@ -862,7 +860,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P return arrayPrefix + this->getTargetName(group.getGroups().at(i), v); }); - if(v.access & VarAccessMode::READ_ONLY) { + if(v.getAccessMode() & VarAccessMode::READ_ONLY) { getContextStream() << "const "; } getContextStream() << resolvedType.getName() << " _" << m_LocalPrefix << v.name; @@ -870,7 +868,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P // If this isn't a reduction, read value from memory // **NOTE** by not initialising these variables for reductions, // compilers SHOULD emit a warning if user code doesn't set it to something - if(!(v.access & VarAccessModeAttribute::REDUCE)) { + if(!(v.getAccessMode() & VarAccessModeAttribute::REDUCE)) { getContextStream() << " = group->" << v.name << m_FieldSuffix << "[" << printSubs(this->getReadIndex(m_Group.get(), v), *this) << "]"; } getContextStream() << ";" << std::endl; @@ -882,7 +880,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P // Loop through referenced definitions again for(const auto &v : referencedDefs) { // If we should always copy variable or variable is read-write - if(this->shouldAlwaysCopy(m_Group.get(), v) || v.access & VarAccessMode::READ_WRITE) { + if(this->shouldAlwaysCopy(m_Group.get(), v) || v.getAccessMode() & VarAccessMode::READ_WRITE) { getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << printSubs(this->getWriteIndex(m_Group.get(), v), *this) << "]"; getContextStream() << " = _" << m_LocalPrefix << v.name << ";" << std::endl; } @@ -906,7 +904,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P // Resolve type, add qualifier if required and return const auto resolvedType = var->second.second.type.resolve(m_Context.get()); - const auto qualifiedType = (var->second.second.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; + const auto qualifiedType = (var->second.second.getAccessMode() & VarAccessModeAttribute::READ_ONLY) ? 
resolvedType.addConst() : resolvedType; return {qualifiedType}; } } diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index 242789c5de..967c76b8f6 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -93,7 +93,7 @@ class GENN_EXPORT CustomUpdateBase if(std::any_of(vars.cbegin(), vars.cend(), [duplication](const Models::Base::Var &v) { - const unsigned int access = v.access.value_or(static_cast(VarAccess::READ_WRITE)); + const unsigned int access = v.getAccess(VarAccess::READ_WRITE); return (access & VarAccessModeAttribute::REDUCE) && (access & duplication); })) { @@ -104,8 +104,9 @@ class GENN_EXPORT CustomUpdateBase for(const auto &modelVarRef : getCustomUpdateModel()->getVarRefs()) { // If custom update model reduces into this variable reference and the variable it targets has correct duplication flag const auto &varRef = varRefs.at(modelVarRef.name); - const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - if ((modelVarRef.access & VarAccessModeAttribute::REDUCE) & (varAccess & duplication)) { + if ((modelVarRef.access & VarAccessModeAttribute::REDUCE) + && (varRef.getVar().getAccess(VarAccess::READ_WRITE) & duplication)) + { return true; } } @@ -132,8 +133,7 @@ class GENN_EXPORT CustomUpdateBase // If custom update is batched, check that any variable references to shared variables are read-only // **NOTE** if custom update isn't batched, it's totally fine to write to shared variables - const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - if(m_Batched && (varAccess & VarAccessDuplication::SHARED) + if(m_Batched && (varRef.getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED) && (modelVarRef.access == VarAccessMode::READ_WRITE)) { throw std::runtime_error("Variable references to SHARED variables in batched custom updates cannot be read-write."); diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 1107df12ff..cddb3db408 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -80,6 +80,21 @@ class GENN_EXPORT Base : public Snippet::Base return (std::tie(name, type, access) == std::tie(other.name, other.type, other.access)); } + unsigned int getAccess(VarAccess defaultAccess) const + { + return access.value_or(static_cast(defaultAccess)); + } + + VarAccessMode getAccessMode() const + { + if(access) { + return getVarAccessMode(access.value()); + } + else { + return VarAccessMode::READ_WRITE; + } + } + std::string name; Type::UnresolvedType type; std::optional access; @@ -97,6 +112,11 @@ class GENN_EXPORT Base : public Snippet::Base return (std::tie(name, type, access) == std::tie(other.name, other.type, other.access)); } + VarAccessMode getAccessMode() const + { + return access; + } + std::string name; Type::UnresolvedType type; VarAccessMode access; @@ -406,8 +426,7 @@ void checkVarReferences(const std::unordered_map &varRefs, const // Check that no reduction targets reference duplicated variables // **TODO** default from InitModel class - const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - if((varAccess & VarAccessDuplication::DUPLICATE) + if((varRef.getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::DUPLICATE) && (modelVarRef.access & VarAccessModeAttribute::REDUCE)) { throw std::runtime_error("Reduction target variable reference must be to SHARED or SHARED_NEURON variables."); diff --git 
a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index c3e7eb9e0e..3b347314c7 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -2016,9 +2016,8 @@ void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateG genWriteBackReductions(env, cg, idxName, [&cg](const Models::VarReference &varRef, const std::string &index) { - const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, - getVarAccessDuplication(varAccess), + getVarAccessDuplication(varRef.getVar().getAccess(VarAccess::READ_WRITE)), index); }); } @@ -2028,8 +2027,7 @@ void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateW genWriteBackReductions(env, cg, idxName, [&cg](const Models::WUVarReference &varRef, const std::string &index) { - const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - return cg.getVarRefIndex(getVarAccessDuplication(varAccess), + return cg.getVarRefIndex(getVarAccessDuplication(varRef.getVar().getAccess(VarAccess::READ_WRITE)), index); }); } diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index e3a27dba43..9602a40cfe 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -702,9 +702,8 @@ std::vector BackendBase::genInitReductionTargets(C return genInitReductionTargets(os, cg, idx, [&cg](const Models::VarReference &varRef, const std::string &index) { - const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, - getVarAccessDuplication(varAccess), + getVarAccessDuplication(varRef.getVar().getAccess(VarAccess::READ_WRITE)), index); }); } @@ -714,8 +713,7 @@ std::vector BackendBase::genInitReductionTargets(C return genInitReductionTargets(os, cg, idx, [&cg](const Models::WUVarReference &varRef, const std::string &index) { - const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - return cg.getVarRefIndex(getVarAccessDuplication(varAccess), + return cg.getVarRefIndex(getVarAccessDuplication(varRef.getVar().getAccess(VarAccess::READ_WRITE)), index); }); } diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index eecb5f42f3..80d18f4402 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -37,11 +37,11 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t { boost::uuids::detail::sha1 hashA; Type::updateHash(a.getVar().type, hashA); - Utils::updateHash(getVarAccessDuplication(a.getVar().access), hashA); + Utils::updateHash(getVarAccessDuplication(a.getVar().getAccess(VarAccess::READ_WRITE)), hashA); boost::uuids::detail::sha1 hashB; Type::updateHash(b.getVar().type, hashB); - Utils::updateHash(getVarAccessDuplication(b.getVar().access), hashB); + Utils::updateHash(getVarAccessDuplication(b.getVar().getAccess(VarAccess::READ_WRITE)), hashB); return (hashA.get_digest() < hashB.get_digest()); }); diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc 
index 901d8be7aa..e4d74fce7b 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -71,7 +71,7 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E [this, &varEnv](const std::string&, const Models::VarReference &v) { return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, - getVarAccessDuplication(v.getVar().access), + getVarAccessDuplication(v.getVar().getAccess(VarAccess::READ_WRITE)), "$(id)"); }); @@ -195,8 +195,8 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", [this, &varEnv](const std::string&, const Models::WUVarReference &v) { - return getVarRefIndex(getVarAccessDuplication(v.getVar().access), - varEnv["id_syn"]); + return getVarRefIndex(getVarAccessDuplication(v.getVar().getAccess(VarAccess::READ_WRITE)), + "$(id_syn)"); }); Transpiler::ErrorHandler errorHandler("Custom update '" + getArchetype().getName() + "' update code"); diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index f4bb1a15af..4aa6f81c75 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -1081,8 +1081,9 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, std::vector neuronStatePushPullFunctions; for(const auto &var : neuronModel->getVars()) { const auto &varInit = n.second.getVarInitialisers().at(var.name); - const unsigned int numCopies = getNumVarCopies(var.access, batchSize); - const unsigned int numElements = getNumVarElements(var.access, n.second.getNumNeurons()); + const unsigned int varAccess = var.getAccess(VarAccess::READ_WRITE); + const unsigned int numCopies = getNumVarCopies(varAccess, batchSize); + const unsigned int numElements = getNumVarElements(varAccess, n.second.getNumNeurons()); const size_t count = n.second.isVarQueueRequired(var.name) ? numCopies * numElements * n.second.getNumDelaySlots() : numCopies * numElements; const bool autoInitialized = !Utils::areTokensEmpty(varInit.getCodeTokens()); const auto resolvedType = var.type.resolve(modelMerged.getModel().getTypeContext()); @@ -1152,7 +1153,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerPushFunc, runnerPullFunc, *cs, mem, currentSourcePushPullFunctions, [batchSize, &n](const CurrentSourceInternal&, const Models::Base::Var &var) { - return getVarSize(var.access, n.second.getNumNeurons(), batchSize); + return getVarSize(var.getAccess(VarAccess::READ_WRITE), + n.second.getNumNeurons(), batchSize); }); genRunnerEGPs(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerExtraGlobalParamFunc, *cs); @@ -1175,7 +1177,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerPushFunc, runnerPullFunc, model.getCustomUpdates(), mem, statePushPullFunctions, [batchSize](const CustomUpdateInternal &c, const Models::Base::Var &var) { - return getVarSize(var.access, c.getSize(), batchSize, 1, c.isBatched()); + return getVarSize(var.getAccess(VarAccess::READ_WRITE), + c.getSize(), batchSize, 1, c.isBatched()); }); genCustomUpdate( @@ -1189,7 +1192,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const size_t count = ((sg->getMatrixType() & SynapseMatrixWeight::KERNEL) ? 
sg->getKernelSizeFlattened() : sg->getSrcNeuronGroup()->getNumNeurons() * backend.getSynapticMatrixRowStride(*sg)); - return getVarSize(var.access, count, batchSize, 1, c.isBatched()); + return getVarSize(var.getAccess(VarAccess::READ_WRITE), + count, batchSize, 1, c.isBatched()); }); allVarStreams << std::endl; @@ -1270,7 +1274,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, *sg, mem, [batchSize](const SynapseGroupInternal &sg, const Models::Base::Var &var) { - return getVarSize(var.access, sg.getTrgNeuronGroup()->getNumNeurons(), batchSize); + return getVarSize(var.getAccess(VarAccess::READ_WRITE), + sg.getTrgNeuronGroup()->getNumNeurons(), batchSize); }); } // Loop through fused outgoing synapse populations with weightupdate models that have presynaptic output @@ -1287,7 +1292,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, *sg, mem, [batchSize, preDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) { - return getVarSize(var.access, sg.getSrcNeuronGroup()->getNumNeurons(), + return getVarSize(var.getAccess(VarAccess::READ_WRITE), + sg.getSrcNeuronGroup()->getNumNeurons(), batchSize, preDelaySlots); }); } @@ -1299,7 +1305,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, *sg, mem, [batchSize, postDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) { - return getVarSize(var.access, sg.getTrgNeuronGroup()->getNumNeurons(), + return getVarSize(var.getAccess(VarAccess::READ_WRITE), + sg.getTrgNeuronGroup()->getNumNeurons(), batchSize, postDelaySlots); }); } @@ -1400,15 +1407,16 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const auto &varInit = s.second.getWUVarInitialisers().at(wuVar.name); const bool autoInitialized = !Utils::areTokensEmpty(varInit.getCodeTokens()); const auto resolvedType = wuVar.type.resolve(modelMerged.getModel().getTypeContext()); + const unsigned int wuVarAccess = wuVar.getAccess(VarAccess::READ_WRITE); if(individualWeights) { const size_t size = (size_t)s.second.getSrcNeuronGroup()->getNumNeurons() * (size_t)backend.getSynapticMatrixRowStride(s.second); genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, resolvedType, wuVar.name + s.second.getName(), s.second.getWUVarLocation(wuVar.name), - autoInitialized, size * getNumVarCopies(wuVar.access, batchSize), mem, synapseGroupStatePushPullFunctions); + autoInitialized, size * getNumVarCopies(wuVarAccess, batchSize), mem, synapseGroupStatePushPullFunctions); } else if(kernelWeights) { // Calculate size of kernel - const size_t size = s.second.getKernelSizeFlattened() * getNumVarCopies(wuVar.access, batchSize); + const size_t size = s.second.getKernelSizeFlattened() * getNumVarCopies(wuVarAccess, batchSize); // Generate variable genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, @@ -1443,7 +1451,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, s.second, synapseGroupStatePushPullFunctions, [batchSize](const SynapseGroupInternal &sg, const Models::Base::Var &var) { - return getVarSize(var.access, 
sg.getTrgNeuronGroup()->getNumNeurons(), batchSize); + return getVarSize(var.getAccess(VarAccess::READ_WRITE), + sg.getTrgNeuronGroup()->getNumNeurons(), batchSize); }); } @@ -1455,7 +1464,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, s.second, synapseGroupStatePushPullFunctions, [batchSize, preDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) { - return getVarSize(var.access, sg.getSrcNeuronGroup()->getNumNeurons(), batchSize, preDelaySlots); + return getVarSize(var.getAccess(VarAccess::READ_WRITE), sg.getSrcNeuronGroup()->getNumNeurons(), + batchSize, preDelaySlots); }); } @@ -1468,7 +1478,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, s.second, synapseGroupStatePushPullFunctions, [batchSize, postDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) { - return getVarSize(var.access, sg.getTrgNeuronGroup()->getNumNeurons(), batchSize, postDelaySlots); + return getVarSize(var.getAccess(VarAccess::READ_WRITE), sg.getTrgNeuronGroup()->getNumNeurons(), + batchSize, postDelaySlots); }); } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index de98684730..f90363efcc 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -87,7 +87,7 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e }); // If variable is shared between neurons - if (getVarAccessDuplication(var.access) == VarAccessDuplication::SHARED_NEURON) { + if (getVarAccessDuplication(var.getAccess(VarAccess::READ_WRITE)) == VarAccessDuplication::SHARED_NEURON) { backend.genPopVariableInit( varEnv, [&adaptor, &fieldGroup, &fieldSuffix, &group, &resolvedType, &var, &varInit, batchSize, numDelaySlots] @@ -103,7 +103,7 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e prettyPrintStatements(varInit.getCodeTokens(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all delay slots and batches - genScalarFill(varInitEnv, "_value", "$(value)", getVarAccessDuplication(var.access), + genScalarFill(varInitEnv, "_value", "$(value)", getVarAccessDuplication(var.getAccess(VarAccess::READ_WRITE)), batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); }); } @@ -125,7 +125,8 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e // Fill value across all delay slots and batches genVariableFill(varInitEnv, "_value", "$(value)", "id", "$(" + count + ")", - getVarAccessDuplication(var.access), batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); + getVarAccessDuplication(var.getAccess(VarAccess::READ_WRITE)), + batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); }); } } @@ -184,7 +185,7 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, // Fill value across all batches genVariableFill(varInitEnv, "_value", "$(value)", "id_syn", stride, - getVarAccessDuplication(var.access), batchSize); + getVarAccessDuplication(var.getAccess(VarAccess::READ_WRITE)), batchSize); }); } } diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index a871b43515..83b9a70907 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -242,9 +242,10 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentEx // 
Loop through variables and copy between read and write delay slots // **YUCK** this a bit sketchy as fields may not have been added - could add fields here but need to guarantee uniqueness for(const auto &v : getArchetype().getWUModel()->getPostVars()) { - if(v.access & VarAccessMode::READ_WRITE) { - env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, getVarAccessDuplication(v.access), "$(id)") + "] = "); - env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, getVarAccessDuplication(v.access), "$(id)") + "];"); + const unsigned int varAccess = v.getAccess(VarAccess::READ_WRITE); + if(varAccess & VarAccessMode::READ_WRITE) { + env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, getVarAccessDuplication(varAccess), "$(id)") + "] = "); + env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, getVarAccessDuplication(varAccess), "$(id)") + "];"); } } } @@ -333,9 +334,10 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(EnvironmentEx // Loop through variables and copy between read and write delay slots // **YUCK** this a bit sketchy as fields may not have been added - could add fields here but need to guarantee uniqueness for(const auto &v : getArchetype().getWUModel()->getPreVars()) { - if(v.access & VarAccessMode::READ_WRITE) { - env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, getVarAccessDuplication(v.access), "$(id)") + "] = "); - env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, getVarAccessDuplication(v.access), "$(id)") + "];"); + const unsigned int varAccess = v.getAccess(VarAccess::READ_WRITE); + if(varAccess & VarAccessMode::READ_WRITE) { + env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, getVarAccessDuplication(varAccess), "$(id)") + "] = "); + env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, getVarAccessDuplication(varAccess), "$(id)") + "];"); } } } diff --git a/src/genn/genn/currentSourceModels.cc b/src/genn/genn/currentSourceModels.cc index 8f650a5db4..61838f7238 100644 --- a/src/genn/genn/currentSourceModels.cc +++ b/src/genn/genn/currentSourceModels.cc @@ -37,8 +37,7 @@ void Base::validate(const std::unordered_map ¶mValues, if(std::any_of(vars.cbegin(), vars.cend(), [](const Models::Base::Var &v) { - const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); - return (varAccess & VarAccessModeAttribute::REDUCE); + return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); })) { throw std::runtime_error("Current source models cannot include variables with REDUCE access modes - they are only supported by custom update models"); diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index d55c598678..0e41dd8010 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -180,8 +180,7 @@ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) if (std::any_of(getPreVarReferences().cbegin(), getPreVarReferences().cend(), [](const auto &v) { - const unsigned int varAccess = v.second.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - return (getVarAccessDuplication(varAccess) != VarAccessDuplication::SHARED); + return (getVarAccessDuplication(v.second.getVar().getAccess(VarAccess::READ_WRITE)) != VarAccessDuplication::SHARED); })) { throw 
std::runtime_error("Presynaptic variables referenced by CustomConnectivityUpdate must be SHARED across batches"); @@ -191,8 +190,7 @@ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) if (std::any_of(getPostVarReferences().cbegin(), getPostVarReferences().cend(), [](const auto &v) { - const unsigned int varAccess = v.second.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - return (getVarAccessDuplication(varAccess) != VarAccessDuplication::SHARED); + return (getVarAccessDuplication(v.second.getVar().getAccess(VarAccess::READ_WRITE)) != VarAccessDuplication::SHARED); })) { throw std::runtime_error("Postsynaptic variables referenced by CustomConnectivityUpdate must be SHARED across batches"); diff --git a/src/genn/genn/customConnectivityUpdateModels.cc b/src/genn/genn/customConnectivityUpdateModels.cc index 48ab5b5fb4..f2274039c2 100644 --- a/src/genn/genn/customConnectivityUpdateModels.cc +++ b/src/genn/genn/customConnectivityUpdateModels.cc @@ -61,23 +61,11 @@ void Base::validate(const std::unordered_map ¶mValues, // If any variables have a reduction access mode, give an error // **YUCK** copy-paste from WUM - could go in helper/Models::Base if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v) - { - const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); - return (varAccess & VarAccessModeAttribute::REDUCE); - }) + [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); }) || std::any_of(preVars.cbegin(), preVars.cend(), - [](const Models::Base::Var &v) - { - const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); - return (varAccess & VarAccessModeAttribute::REDUCE); - }) + [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); }) || std::any_of(postVars.cbegin(), postVars.cend(), - [](const Models::Base::Var &v) - { - const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); - return (varAccess & VarAccessModeAttribute::REDUCE); - })) + [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); })) { throw std::runtime_error("Custom connectivity update models cannot include variables with REDUCE access modes - they are only supported by custom update models"); } @@ -87,8 +75,7 @@ void Base::validate(const std::unordered_map ¶mValues, if (std::any_of(vars.cbegin(), vars.cend(), [](const Models::Base::Var &v) { - const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); - return (varAccess & VarAccessDuplication::SHARED_NEURON); + return (v.getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED_NEURON); })) { throw std::runtime_error("Custom connectivity update models cannot include variables with SHARED_NEURON access modes - they are only supported on pre, postsynaptic or neuron variables"); diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index 0d15cecca7..76902a785a 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -139,14 +139,12 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro m_PerNeuron = std::any_of(m_VarReferences.cbegin(), m_VarReferences.cend(), [](const auto& v) { - const unsigned int varAccess = v.second.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - return !(varAccess & VarAccessDuplication::SHARED_NEURON); + return !(v.second.getVar().getAccess(VarAccess::READ_WRITE) & 
VarAccessDuplication::SHARED_NEURON); }); m_PerNeuron |= std::any_of(modelVars.cbegin(), modelVars.cend(), [](const Models::Base::Var& v) { - const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); - return !(varAccess & VarAccessDuplication::SHARED_NEURON); + return !(v.getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED_NEURON); }); // Loop through all variable references @@ -155,8 +153,7 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro // If custom update is per-neuron, check that any variable references to SHARED_NEURON variables are read-only // **NOTE** if custom update isn't per-neuron, it's totally fine to write to SHARED_NEURON variables - const unsigned int varAccess = varRef.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - if(m_PerNeuron && (varAccess & VarAccessDuplication::SHARED_NEURON) + if(m_PerNeuron && (varRef.getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED_NEURON) && (modelVarRef.access == VarAccessMode::READ_WRITE)) { throw std::runtime_error("Variable references to SHARED_NEURON variables in per-neuron custom updates cannot be read-write."); @@ -222,8 +219,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdate::getHashDigest() const Utils::updateHash((v.second.getDelayNeuronGroup() == nullptr), hash); // Update hash with duplication mode of target variable as this effects indexing code - const unsigned int varAccess = v.second.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - Utils::updateHash(getVarAccessDuplication(varAccess), hash); + Utils::updateHash(getVarAccessDuplication(v.second.getVar().getAccess(VarAccess::READ_WRITE)), hash); } return hash.get_digest(); } @@ -275,8 +271,7 @@ CustomUpdateWU::CustomUpdateWU(const std::string &name, const std::string &updat if (std::any_of(vars.cbegin(), vars.cend(), [](const Models::Base::Var &v) { - const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); - return (varAccess & VarAccessDuplication::SHARED_NEURON); + return (v.getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED_NEURON); })) { throw std::runtime_error("Custom weight updates cannot use models with SHARED_NEURON variables."); @@ -346,8 +341,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getHashDigest() const Utils::updateHash((v.second.getTransposeSynapseGroup() == nullptr), hash); // Update hash with duplication mode of target variable as this effects indexing code - const unsigned int varAccess = v.second.getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - Utils::updateHash(getVarAccessDuplication(varAccess), hash); + Utils::updateHash(getVarAccessDuplication(v.second.getVar().getAccess(VarAccess::READ_WRITE)), hash); } return hash.get_digest(); diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index c5d8a5ca77..3790ae7b8f 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -75,8 +75,7 @@ std::string VarReference::getTargetName() const //---------------------------------------------------------------------------- bool VarReference::isDuplicated() const { - const unsigned int varAccess = getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - if(varAccess & VarAccessDuplication::SHARED) { + if(getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED) { return false; } else { @@ -176,8 +175,7 @@ std::string WUVarReference::getTargetName() const 
//---------------------------------------------------------------------------- bool WUVarReference::isDuplicated() const { - const unsigned int varAccess = getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - if(varAccess & VarAccessDuplication::SHARED) { + if(getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED) { return false; } else { @@ -332,9 +330,9 @@ WUVarReference::WUVarReference(size_t varIndex, const Models::Base::VarVec &varV } // Check duplicatedness of variables - const unsigned int varAccess = getVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - const unsigned int transposeVarAccess = getTransposeVar().access.value_or(static_cast(VarAccess::READ_WRITE)); - if((varAccess & VarAccessDuplication::DUPLICATE) != (transposeVarAccess & VarAccessDuplication::DUPLICATE)) { + if((getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::DUPLICATE) + != (getTransposeVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::DUPLICATE)) + { throw std::runtime_error("Transpose updates can only be performed on similarly batched variables"); } } diff --git a/src/genn/genn/neuronModels.cc b/src/genn/genn/neuronModels.cc index 152dc43881..da52358947 100644 --- a/src/genn/genn/neuronModels.cc +++ b/src/genn/genn/neuronModels.cc @@ -50,11 +50,7 @@ void Base::validate(const std::unordered_map ¶mValues, // If any variables have a reduction access mode, give an error const auto vars = getVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v) - { - const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); - return (varAccess & VarAccessModeAttribute::REDUCE); - })) + [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); })) { throw std::runtime_error("Neuron models cannot include variables with REDUCE access modes - they are only supported by custom update models"); } diff --git a/src/genn/genn/postsynapticModels.cc b/src/genn/genn/postsynapticModels.cc index bf23c50a1a..2713a9b89b 100644 --- a/src/genn/genn/postsynapticModels.cc +++ b/src/genn/genn/postsynapticModels.cc @@ -35,11 +35,7 @@ void Base::validate(const std::unordered_map ¶mValues, const auto vars = getVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v) - { - const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); - return (varAccess & VarAccessModeAttribute::REDUCE); - })) + [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); })) { throw std::runtime_error("Postsynaptic models cannot include variables with REDUCE access modes - they are only supported by custom update models"); } diff --git a/src/genn/genn/weightUpdateModels.cc b/src/genn/genn/weightUpdateModels.cc index 9a2a1b8b24..d1234f78eb 100644 --- a/src/genn/genn/weightUpdateModels.cc +++ b/src/genn/genn/weightUpdateModels.cc @@ -85,23 +85,11 @@ void Base::validate(const std::unordered_map ¶mValues, const auto preVars = getPreVars(); const auto postVars = getPostVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v) - { - const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); - return (varAccess & VarAccessModeAttribute::REDUCE); - }) + [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); }) || std::any_of(preVars.cbegin(), preVars.cend(), - [](const Models::Base::Var &v) - { - const unsigned int varAccess = 
v.access.value_or(static_cast(VarAccess::READ_WRITE)); - return (varAccess & VarAccessModeAttribute::REDUCE); - }) + [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); }) || std::any_of(postVars.cbegin(), postVars.cend(), - [](const Models::Base::Var &v) - { - const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); - return (varAccess & VarAccessModeAttribute::REDUCE); - })) + [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); })) { throw std::runtime_error("Weight update models cannot include variables with REDUCE access modes - they are only supported by custom update models"); } @@ -116,8 +104,7 @@ void Base::validate(const std::unordered_map ¶mValues, if (std::any_of(vars.cbegin(), vars.cend(), [](const Models::Base::Var &v) { - const unsigned int varAccess = v.access.value_or(static_cast(VarAccess::READ_WRITE)); - return (varAccess & VarAccessDuplication::SHARED_NEURON); + return (v.getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED_NEURON); })) { throw std::runtime_error("Weight update models cannot include variables with SHARED_NEURON access modes - they are only supported on pre, postsynaptic or neuron variables"); From deb43674b2a30b61783d65988ce49fa0e95f53dc Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 17 Aug 2023 17:34:55 +0100 Subject: [PATCH 04/60] * Remove VarAccess and VarAccessDuplication * Replace queries with VarAccessDim --- .../backends/single_threaded_cpu/backend.h | 2 +- .../genn/genn/code_generator/backendBase.h | 2 +- .../code_generator/customUpdateGroupMerged.h | 8 +- .../genn/genn/code_generator/environment.h | 4 +- .../code_generator/neuronUpdateGroupMerged.h | 6 +- .../code_generator/synapseUpdateGroupMerged.h | 32 +++---- include/genn/genn/customUpdate.h | 25 +++--- include/genn/genn/varAccess.h | 54 ++++------- .../backends/single_threaded_cpu/backend.cc | 4 +- src/genn/genn/code_generator/backendBase.cc | 4 +- src/genn/genn/code_generator/backendSIMT.cc | 8 +- .../customConnectivityUpdateGroupMerged.cc | 4 +- .../code_generator/customUpdateGroupMerged.cc | 45 +++++----- .../genn/code_generator/generateRunner.cc | 4 +- .../genn/code_generator/initGroupMerged.cc | 55 ++++++------ .../code_generator/neuronUpdateGroupMerged.cc | 90 ++++++++++--------- .../presynapticUpdateStrategySIMT.cc | 10 +-- .../synapseUpdateGroupMerged.cc | 79 ++++++++-------- src/genn/genn/customConnectivityUpdate.cc | 6 +- .../genn/customConnectivityUpdateModels.cc | 2 +- src/genn/genn/customUpdate.cc | 18 ++-- src/genn/genn/models.cc | 22 ++--- 22 files changed, 237 insertions(+), 247 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 0d5010d6ec..534e6ced0b 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -255,7 +255,7 @@ class BACKEND_EXPORT Backend : public BackendBase const unsigned int varAccess = v.getAccess(VarAccess::READ_WRITE); if(varAccess & VarAccessModeAttribute::REDUCE) { const std::string idx = env.getName(idxName); - env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(getVarAccessDuplication(varAccess), idx) << "] = " << env[v.name] << ";" << std::endl; + env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(getVarAccessDim(varAccess), idx) << "] = " << env[v.name] << ";" << std::endl; } } diff --git a/include/genn/genn/code_generator/backendBase.h 
b/include/genn/genn/code_generator/backendBase.h index dc5354b13a..ccf9880b37 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -573,7 +573,7 @@ class GENN_EXPORT BackendBase const auto resolvedType = v.type.resolve(cg.getTypeContext()); os << resolvedType.getName() << " _lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(varAccess), resolvedType) << ";" << std::endl; reductionTargets.push_back({v.name, resolvedType, getVarAccessMode(varAccess), - cg.getVarIndex(getVarAccessDuplication(varAccess), idx)}); + cg.getVarIndex(varAccess), idx)}); } } diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index ab6e73d4c1..fbaa8dc7ad 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -32,8 +32,8 @@ class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged genPostamble); - std::string getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const; - std::string getVarRefIndex(bool delay, VarAccessDuplication varDuplication, const std::string &index) const; + std::string getVarIndex(unsigned int varAccess, const std::string &index) const; + std::string getVarRefIndex(bool delay, unsigned int varAccess, const std::string &index) const; //---------------------------------------------------------------------------- // Static constants @@ -67,8 +67,8 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged genPostamble); - std::string getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const; - std::string getVarRefIndex(VarAccessDuplication varDuplication, const std::string &index) const; + std::string getVarIndex(unsigned int varAccess, const std::string &index) const; + std::string getVarRefIndex(unsigned int varAccess, const std::string &index) const; }; diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 44e2c77807..3ef7d5b24d 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -704,8 +704,8 @@ class VarCachePolicy { public: using GroupInternal = typename G::GroupInternal; - using GetIndexFn = std::function; - using ShouldAlwaysCopyFn = std::function; + using GetIndexFn = std::function; + using ShouldAlwaysCopyFn = std::function; VarCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex, ShouldAlwaysCopyFn shouldAlwaysCopy = ShouldAlwaysCopyFn()) diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index b8a00dac1a..a9bd8edd4a 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -185,9 +185,9 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase void generateWUVarUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize); - std::string getVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; - std::string getReadVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; - std::string getWriteVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; + std::string getVarIndex(unsigned int 
batchSize, unsigned int varAccess, const std::string &index) const; + std::string getReadVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const; + std::string getWriteVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const; const std::vector &getMergedCurrentSourceGroups() const { return m_MergedCurrentSourceGroups; } const std::vector &getMergedInSynPSMGroups() const { return m_MergedInSynPSMGroups; } diff --git a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h index 4ea6197995..1920602071 100644 --- a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h @@ -44,33 +44,33 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMergedisDelayRequired(), batchSize, varDuplication, index); + return getPreVarIndex(getArchetype().getSrcNeuronGroup()->isDelayRequired(), batchSize, varAccess, index); } - std::string getPostVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const + std::string getPostVarIndex(unsigned int batchSize, unsigned int varAccess, const std::string &index) const { - return getPostVarIndex(getArchetype().getTrgNeuronGroup()->isDelayRequired(), batchSize, varDuplication, index); + return getPostVarIndex(getArchetype().getTrgNeuronGroup()->isDelayRequired(), batchSize, varAccess, index); } - std::string getPreWUVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const + std::string getPreWUVarIndex(unsigned int batchSize, unsigned int varAccess, const std::string &index) const { - return getPreVarIndex(getArchetype().getDelaySteps() != 0, batchSize, varDuplication, index); + return getPreVarIndex(getArchetype().getDelaySteps() != 0, batchSize, varAccess, index); } - std::string getPostWUVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const + std::string getPostWUVarIndex(unsigned int batchSize, unsigned int varAccess, const std::string &index) const { - return getPostVarIndex(getArchetype().getBackPropDelaySteps() != 0, batchSize, varDuplication, index); + return getPostVarIndex(getArchetype().getBackPropDelaySteps() != 0, batchSize, varAccess, index); } std::string getPostDenDelayIndex(unsigned int batchSize, const std::string &index, const std::string &offset) const; - std::string getPreVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; - std::string getPostVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; + std::string getPreVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const; + std::string getPostVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const; - std::string getPrePrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; - std::string getPostPrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; + std::string getPrePrevSpikeTimeIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const; + std::string getPostPrevSpikeTimeIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string 
&index) const; std::string getPostISynIndex(unsigned int batchSize, const std::string &index) const { @@ -82,8 +82,8 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged getUpdateCodeTokens() const{ return m_UpdateCodeTokens; } template - bool isReduction(const std::unordered_map &varRefs, VarAccessDuplication duplication) const + bool isReduction(const std::unordered_map &varRefs, VarAccessDim reduceDim) const { - // Return true if any variables have REDUCE flag in their access mode and have correct duplication flag + // Return true if any variables have REDUCE flag in their access mode and doesn't have reduction dimension const auto vars = getCustomUpdateModel()->getVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [duplication](const Models::Base::Var &v) + [reduceDim](const Models::Base::Var &v) { const unsigned int access = v.getAccess(VarAccess::READ_WRITE); - return (access & VarAccessModeAttribute::REDUCE) && (access & duplication); + return (access & VarAccessModeAttribute::REDUCE) && !(access & reduceDim); })) { return true; @@ -102,10 +102,11 @@ class GENN_EXPORT CustomUpdateBase // Loop through all variable references for(const auto &modelVarRef : getCustomUpdateModel()->getVarRefs()) { - // If custom update model reduces into this variable reference and the variable it targets has correct duplication flag + // If custom update model reduces into this variable reference + // and the variable it targets doesn't have reduction dimension const auto &varRef = varRefs.at(modelVarRef.name); if ((modelVarRef.access & VarAccessModeAttribute::REDUCE) - && (varRef.getVar().getAccess(VarAccess::READ_WRITE) & duplication)) + && !(varRef.getVar().getAccess(VarAccess::READ_WRITE) & reduceDim)) { return true; } @@ -131,12 +132,12 @@ class GENN_EXPORT CustomUpdateBase for(const auto &modelVarRef : getCustomUpdateModel()->getVarRefs()) { const auto varRef = varRefs.at(modelVarRef.name); - // If custom update is batched, check that any variable references to shared variables are read-only + // If custom update is batched, check that any variable references to variables that aren't batched are read-only // **NOTE** if custom update isn't batched, it's totally fine to write to shared variables - if(m_Batched && (varRef.getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED) + if(m_Batched && !(varRef.getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDim::BATCH) && (modelVarRef.access == VarAccessMode::READ_WRITE)) { - throw std::runtime_error("Variable references to SHARED variables in batched custom updates cannot be read-write."); + throw std::runtime_error("Variable references to non-batched variables in batched custom updates cannot be read-write."); } } } @@ -264,8 +265,8 @@ class GENN_EXPORT CustomUpdate : public CustomUpdateBase //------------------------------------------------------------------------ // Protected const methods //------------------------------------------------------------------------ - bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDuplication::SHARED); } - bool isNeuronReduction() const { return isReduction(getVarReferences(), VarAccessDuplication::SHARED_NEURON); } + bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDim::BATCH); } + bool isNeuronReduction() const { return isReduction(getVarReferences(), VarAccessDim::NEURON); } bool isPerNeuron() const{ return m_PerNeuron; } const NeuronGroup *getDelayNeuronGroup() const { return m_DelayNeuronGroup; } @@ -321,7 +322,7 @@ 
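Editor's note: the reworked isReduction above inverts the old test: a variable is now a reduction target when its access carries the REDUCE attribute but lacks the dimension being reduced over, so a batch reduction writes into variables without the BATCH dimension. The sketch below restates that predicate as a free function; the dimension flag values follow the header introduced in this series, while the REDUCE attribute bit is an illustrative placeholder.

// Sketch of the reduction-target test used above: REDUCE attribute present,
// reduced-over dimension absent. The attribute bit value is a placeholder.
#include <cassert>

enum class VarAccessModeAttribute : unsigned int { REDUCE = (1 << 2) };   // placeholder bit
enum class VarAccessDim : unsigned int { NEURON = (1 << 5), BATCH = (1 << 8) };

inline bool operator & (unsigned int access, VarAccessModeAttribute attr)
{
    return (access & static_cast<unsigned int>(attr)) != 0;
}

inline bool operator & (unsigned int access, VarAccessDim dim)
{
    return (access & static_cast<unsigned int>(dim)) != 0;
}

// A reduction over reduceDim targets variables that are reducible but lack that dimension
bool isReductionTarget(unsigned int access, VarAccessDim reduceDim)
{
    return (access & VarAccessModeAttribute::REDUCE) && !(access & reduceDim);
}

int main()
{
    const unsigned int reducePerNeuron = static_cast<unsigned int>(VarAccessModeAttribute::REDUCE)
                                         | static_cast<unsigned int>(VarAccessDim::NEURON);
    const unsigned int reduceBatched = reducePerNeuron | static_cast<unsigned int>(VarAccessDim::BATCH);

    assert(isReductionTarget(reducePerNeuron, VarAccessDim::BATCH));   // no BATCH dimension: batch-reduction target
    assert(!isReductionTarget(reduceBatched, VarAccessDim::BATCH));    // has BATCH dimension: not a target
    return 0;
}

(end of note; the patch continues below)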
class GENN_EXPORT CustomUpdateWU : public CustomUpdateBase //------------------------------------------------------------------------ // Protected const methods //------------------------------------------------------------------------ - bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDuplication::SHARED); } + bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDim::BATCH); } bool isTransposeOperation() const; SynapseGroupInternal *getSynapseGroup() const { return m_SynapseGroup; } diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h index 2c1cd56022..8c17395a83 100644 --- a/include/genn/genn/varAccess.h +++ b/include/genn/genn/varAccess.h @@ -25,22 +25,13 @@ enum class VarAccessMode : unsigned int REDUCE_MAX = static_cast(VarAccessModeAttribute::REDUCE) | static_cast(VarAccessModeAttribute::MAX), }; -//! Flags defining how variables should be duplicated across multiple batches -enum class VarAccessDuplication : unsigned int -{ - DUPLICATE = (1 << 5), //! This variable should be duplicated in each batch - SHARED = (1 << 6), //! This variable should be shared between batches - SHARED_NEURON = (1 << 7) //! This variable should be shared between neurons -}; - //! Flags defining dimensions this variables has enum class VarAccessDim : unsigned int { NEURON = (1 << 5), PRE_NEURON = (1 << 6), POST_NEURON = (1 << 7), - DELAY = (1 << 8), - BATCH = (1 << 9), + BATCH = (1 << 8), }; //! Supported combinations of access mode and dimension for neuron variables @@ -53,7 +44,7 @@ enum class NeuronVarAccess : unsigned int }; //! Supported combinations of access mode and dimension for synapse variables -/*enum class SynapseVarAccess : unsigned int +enum class SynapseVarAccess : unsigned int { // Synaptic variables READ_WRITE = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::PRE_NEURON) | static_cast(VarAccessDim::POST_NEURON) | static_cast(VarAccessDim::BATCH), @@ -61,17 +52,17 @@ enum class NeuronVarAccess : unsigned int READ_ONLY_DUPLICATE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::PRE_NEURON) | static_cast(VarAccessDim::POST_NEURON) | static_cast(VarAccessDim::BATCH), // Presynaptic variables - READ_WRITE_PRE = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::PRE_NEURON) | static_cast(VarAccessDim::BATCH), - READ_ONLY_PRE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::PRE_NEURON), - READ_ONLY_PRE_DUPLICATE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::PRE_NEURON) | static_cast(VarAccessDim::BATCH), + //READ_WRITE_PRE = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::PRE_NEURON) | static_cast(VarAccessDim::BATCH), + //READ_ONLY_PRE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::PRE_NEURON), + //READ_ONLY_PRE_DUPLICATE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::PRE_NEURON) | static_cast(VarAccessDim::BATCH), // Postsynaptic variables - READ_WRITE_POST = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::POST_NEURON) | static_cast(VarAccessDim::BATCH), - READ_ONLY_POST = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::POST_NEURON), - READ_ONLY_POST_DUPLICATE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::POST_NEURON) | static_cast(VarAccessDim::BATCH), + //READ_WRITE_POST = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::POST_NEURON) | static_cast(VarAccessDim::BATCH), + //READ_ONLY_POST 
= static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::POST_NEURON), + //READ_ONLY_POST_DUPLICATE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::POST_NEURON) | static_cast(VarAccessDim::BATCH), }; -enum class CustomUpdateVarAccess : unsigned int +/*enum class CustomUpdateVarAccess : unsigned int { // Variables with matching shape READ_WRITE, @@ -95,19 +86,6 @@ enum class CustomUpdateVarAccess : unsigned int REDUCE_POST_NEURON_MAX, }*/ -//! Supported combinations of VarAccessMode and VarAccessDuplication -enum class VarAccess : unsigned int -{ - READ_WRITE = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDuplication::DUPLICATE), - READ_ONLY = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDuplication::SHARED), - READ_ONLY_SHARED_NEURON = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDuplication::SHARED_NEURON), - READ_ONLY_DUPLICATE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDuplication::DUPLICATE), - REDUCE_BATCH_SUM = static_cast(VarAccessMode::REDUCE_SUM) | static_cast(VarAccessDuplication::SHARED), - REDUCE_BATCH_MAX = static_cast(VarAccessMode::REDUCE_MAX) | static_cast(VarAccessDuplication::SHARED), - REDUCE_NEURON_SUM = static_cast(VarAccessMode::REDUCE_SUM) | static_cast(VarAccessDuplication::SHARED_NEURON), - REDUCE_NEURON_MAX = static_cast(VarAccessMode::REDUCE_MAX) | static_cast(VarAccessDuplication::SHARED_NEURON), -}; - //---------------------------------------------------------------------------- // Operators //---------------------------------------------------------------------------- @@ -116,9 +94,9 @@ inline bool operator & (unsigned int type, VarAccessMode mode) return (type & static_cast(mode)) != 0; } -inline bool operator & (unsigned int type, VarAccessDuplication duplication) +inline bool operator & (unsigned int type, VarAccessDim dim) { - return (type & static_cast(duplication)) != 0; + return (type & static_cast(dim)) != 0; } inline bool operator & (unsigned int type, VarAccessModeAttribute modeAttribute) @@ -136,6 +114,11 @@ inline bool operator & (VarAccessMode a, VarAccessMode b) return (static_cast(a) & static_cast(b)) != 0; } +inline unsigned int operator | (VarAccessDim a, VarAccessDim b) +{ + return (static_cast(a) | static_cast(b)); +} + //---------------------------------------------------------------------------- // Helpers @@ -144,9 +127,4 @@ inline VarAccessMode getVarAccessMode(unsigned int type) { return static_cast(type & 0x1F); } - -inline VarAccessDuplication getVarAccessDuplication(unsigned int type) -{ - return static_cast(type & ~0x1F); -} } // namespace GeNN diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 3b347314c7..4d62c764d1 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -2017,7 +2017,7 @@ void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateG [&cg](const Models::VarReference &varRef, const std::string &index) { return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, - getVarAccessDuplication(varRef.getVar().getAccess(VarAccess::READ_WRITE)), + varRef.getVar().getAccess(NeuronVarAccess::READ_WRITE), index); }); } @@ -2027,7 +2027,7 @@ void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateW genWriteBackReductions(env, cg, idxName, [&cg](const Models::WUVarReference &varRef, const std::string &index) { - return 
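Editor's note: after the varAccess.h rework above, a variable's access is carried around as a plain unsigned int in which the low five bits hold the VarAccessMode (hence the 0x1F mask in getVarAccessMode) and the bits from (1 << 5) upwards hold the VarAccessDim flags. The sketch below shows how such a value is composed and queried; the dimension flag values follow the header, while the mode values are placeholders that merely stay inside the low-five-bit mask.

// Sketch of the mode-plus-dimension bit encoding used by the reworked varAccess.h.
// Dimension flags follow the header above; the mode values are illustrative placeholders.
#include <cassert>

enum class VarAccessMode : unsigned int { READ_WRITE = 0x03, READ_ONLY = 0x01 };   // placeholder values
enum class VarAccessDim  : unsigned int { NEURON = (1 << 5), BATCH = (1 << 8) };

inline unsigned int operator | (VarAccessDim a, VarAccessDim b)
{
    return static_cast<unsigned int>(a) | static_cast<unsigned int>(b);
}

inline bool operator & (unsigned int access, VarAccessDim dim)
{
    return (access & static_cast<unsigned int>(dim)) != 0;
}

inline VarAccessMode getVarAccessMode(unsigned int access)
{
    return static_cast<VarAccessMode>(access & 0x1F);   // mode lives in the low five bits
}

int main()
{
    // Compose a "read-write, per-neuron, per-batch" access value
    const unsigned int access = static_cast<unsigned int>(VarAccessMode::READ_WRITE)
                                | (VarAccessDim::NEURON | VarAccessDim::BATCH);

    assert(access & VarAccessDim::NEURON);                          // NEURON dimension present
    assert(access & VarAccessDim::BATCH);                           // BATCH dimension present
    assert(getVarAccessMode(access) == VarAccessMode::READ_WRITE);  // mode recovered from the low bits
    return 0;
}

(end of note; the patch continues below)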
cg.getVarRefIndex(getVarAccessDuplication(varRef.getVar().getAccess(VarAccess::READ_WRITE)), + return cg.getVarRefIndex(varRef.getVar().getAccess(SynapseVarAccess::READ_WRITE), index); }); } diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 9602a40cfe..7ef2780dc9 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -703,7 +703,7 @@ std::vector BackendBase::genInitReductionTargets(C [&cg](const Models::VarReference &varRef, const std::string &index) { return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, - getVarAccessDuplication(varRef.getVar().getAccess(VarAccess::READ_WRITE)), + varRef.getVar().getAccess(NeuronVarAccess::READ_WRITE), index); }); } @@ -713,7 +713,7 @@ std::vector BackendBase::genInitReductionTargets(C return genInitReductionTargets(os, cg, idx, [&cg](const Models::WUVarReference &varRef, const std::string &index) { - return cg.getVarRefIndex(getVarAccessDuplication(varRef.getVar().getAccess(VarAccess::READ_WRITE)), + return cg.getVarRefIndex(varRef.getVar().getAccess(SynapseVarAccess::READ_WRITE), index); }); } diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 6a96fed619..03c1cbe1de 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -514,7 +514,7 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM // Add population RNG field groupEnv.addField(getPopulationRNGType().createPointer(), "_rng", "rng", [this](const auto &g, size_t) { return getDeviceVarPrefix() + "rng" + g.getName(); }, - ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)")); + ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)")); // **TODO** for OCL do genPopulationRNGPreamble(os, popSubs, "group->rng[" + ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)") + "]") in initialiser ng.generateNeuronUpdate(*this, groupEnv, batchSize, @@ -587,10 +587,10 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM } // Copy spikes into block of $(_spk) - const std::string queueOffset = ng.getWriteVarIndex(ng.getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, ""); + const std::string queueOffset = ng.getWriteVarIndex(ng.getArchetype().isDelayRequired(), batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, ""); if(!Utils::areTokensEmpty(ng.getArchetype().getThresholdConditionCodeTokens())) { const std::string queueOffsetTrueSpk = ng.getWriteVarIndex(ng.getArchetype().isTrueSpikeRequired() && ng.getArchetype().isDelayRequired(), - batchSize, VarAccessDuplication::DUPLICATE, ""); + batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, ""); groupEnv.print("if(" + getThreadID() + " < $(_sh_spk_count))"); { CodeStream::Scope b(groupEnv.getStream()); @@ -805,7 +805,7 @@ void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, Mode { CodeStream::Scope b(groupEnv.getStream()); const std::string index = "(r * " + std::to_string(getKernelBlockSize(KernelPostsynapticUpdate)) + ") + " + getThreadID(); - groupEnv.printLine("const unsigned int spk = $(_trg_spk)[" + sg.getPostVarIndex(batchSize, VarAccessDuplication::DUPLICATE, index) + "];"); + groupEnv.printLine("const unsigned int spk = $(_trg_spk)[" + sg.getPostVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::POST_NEURON, index) + "];"); groupEnv.getStream() << "shSpk[" << 
getThreadID() << "] = spk;" << std::endl; if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 80d18f4402..1871a8679b 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -37,11 +37,11 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t { boost::uuids::detail::sha1 hashA; Type::updateHash(a.getVar().type, hashA); - Utils::updateHash(getVarAccessDuplication(a.getVar().getAccess(VarAccess::READ_WRITE)), hashA); + Utils::updateHash(a.getVar().getAccess(VarAccess::READ_WRITE), hashA); boost::uuids::detail::sha1 hashB; Type::updateHash(b.getVar().type, hashB); - Utils::updateHash(getVarAccessDuplication(b.getVar().getAccess(VarAccess::READ_WRITE)), hashB); + Utils::updateHash(b.getVar().getAccess(VarAccess::READ_WRITE), hashB); return (hashA.get_digest() < hashB.get_digest()); }); diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index e4d74fce7b..f26ea1d336 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -60,7 +60,7 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", - [this, &cuEnv](const std::string&, VarAccessDuplication d) + [this, &cuEnv](const std::string&, unsigned int d) { return getVarIndex(d, "$(id)"); }); @@ -71,7 +71,7 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E [this, &varEnv](const std::string&, const Models::VarReference &v) { return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, - getVarAccessDuplication(v.getVar().getAccess(VarAccess::READ_WRITE)), + v.getVar().getAccess(NeuronVarAccess::READ_WRITE), "$(id)"); }); @@ -82,40 +82,43 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E genPostamble(varRefEnv, *this); } //---------------------------------------------------------------------------- -std::string CustomUpdateGroupMerged::getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const +std::string CustomUpdateGroupMerged::getVarIndex(unsigned int varAccess, const std::string &index) const { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? - if (varDuplication == VarAccessDuplication::SHARED_NEURON) { - return getArchetype().isBatched() ? "$(batch)" : "0"; + const bool batched = (varAccess & VarAccessDim::BATCH) && getArchetype().isBatched(); + if (!(varAccess & VarAccessDim::NEURON)) { + return batched ? 
"$(batch)" : "0"; } - else if (varDuplication == VarAccessDuplication::SHARED || !getArchetype().isBatched()) { + else if (batched) { assert(!index.empty()); - return index; + return "$(_batch_offset) + " + index; } else { assert(!index.empty()); - return "$(_batch_offset) + " + index; + return index; } } //---------------------------------------------------------------------------- -std::string CustomUpdateGroupMerged::getVarRefIndex(bool delay, VarAccessDuplication varDuplication, const std::string &index) const +std::string CustomUpdateGroupMerged::getVarRefIndex(bool delay, unsigned int varAccess, const std::string &index) const { // If delayed, variable is shared, the batch size is one or this custom update isn't batched, batch delay offset isn't required if(delay) { - if (varDuplication == VarAccessDuplication::SHARED_NEURON) { - return getArchetype().isBatched() ? "$(_batch_delay_slot)" : "$(_delay_slot)"; + const bool batched = (varAccess & VarAccessDim::BATCH) && getArchetype().isBatched(); + if (!(varAccess & VarAccessDim::NEURON)) { + return batched ? "$(_batch_delay_slot)" : "$(_delay_slot)"; } - else if (varDuplication == VarAccessDuplication::SHARED || !getArchetype().isBatched()) { + else if (batched) { assert(!index.empty()); - return "$(_delay_offset) + " + index; + return "$(_batch_delay_offset) + " + index; } + else { assert(!index.empty()); - return "$(_batch_delay_offset) + " + index; + return "$(_delay_offset) + " + index; } } else { - return getVarIndex(varDuplication, index); + return getVarIndex(varAccess, index); } } //---------------------------------------------------------------------------- @@ -185,7 +188,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", - [this, &cuEnv](const std::string&, VarAccessDuplication d) + [this, &cuEnv](const std::string&, unsigned int d) { return getVarIndex(d, "$(id_syn)"); }); @@ -195,7 +198,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", [this, &varEnv](const std::string&, const Models::WUVarReference &v) { - return getVarRefIndex(getVarAccessDuplication(v.getVar().getAccess(VarAccess::READ_WRITE)), + return getVarRefIndex(v.getVar().getAccess(SynapseVarAccess::READ_WRITE), "$(id_syn)"); }); @@ -206,16 +209,16 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back genPostamble(varRefEnv, *this); } //---------------------------------------------------------------------------- -std::string CustomUpdateWUGroupMergedBase::getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const +std::string CustomUpdateWUGroupMergedBase::getVarIndex(unsigned int varAccess, const std::string &index) const { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? - return ((varDuplication == VarAccessDuplication::SHARED || !getArchetype().isBatched()) ? "" : "$(_batch_offset) + ") + index; + return (((varAccess & VarAccessDim::BATCH) && getArchetype().isBatched()) ? 
"$(_batch_offset) + " : "") + index; } //---------------------------------------------------------------------------- -std::string CustomUpdateWUGroupMergedBase::getVarRefIndex(VarAccessDuplication varDuplication, const std::string &index) const +std::string CustomUpdateWUGroupMergedBase::getVarRefIndex(unsigned int varAccess, const std::string &index) const { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? - return ((varDuplication == VarAccessDuplication::SHARED || !getArchetype().isBatched()) ? "" : "$(_batch_offset) + ") + index; + return (((varAccess & VarAccessDim::BATCH) && getArchetype().isBatched()) ? "$(_batch_offset) + " : "") + index; } // ---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 4aa6f81c75..1013b8efe1 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -29,12 +29,12 @@ namespace { unsigned int getNumVarCopies(unsigned int varAccess, unsigned int batchSize, bool batched = true) { - return ((varAccess & VarAccessDuplication::SHARED) || !batched) ? 1 : batchSize; + return ((varAccess & VarAccessDim::BATCH) && batched) ? batchSize : 1; } //-------------------------------------------------------------------------- unsigned int getNumVarElements(unsigned int varAccess, unsigned int numNeurons) { - return (varAccess & VarAccessDuplication::SHARED_NEURON) ? 1 : numNeurons; + return (varAccess & VarAccessDim::NEURON) ? numNeurons : 1; } //-------------------------------------------------------------------------- unsigned int getVarSize(unsigned int varAccess, unsigned int numElements, unsigned int batchSize, diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index f90363efcc..fecac43363 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -21,10 +21,10 @@ using namespace GeNN::Transpiler; namespace { void genVariableFill(EnvironmentExternalBase &env, const std::string &target, const std::string &value, const std::string &idx, const std::string &stride, - VarAccessDuplication varDuplication, unsigned int batchSize, bool delay = false, unsigned int numDelaySlots = 1) + unsigned int varAccess, unsigned int batchSize, bool delay = false, unsigned int numDelaySlots = 1) { // Determine number of values to fill in each thread - const unsigned int numValues = ((varDuplication == VarAccessDuplication::SHARED) ? 1 : batchSize) * ((delay ? numDelaySlots : 1)); + const unsigned int numValues = ((varAccess & VarAccessDim::BATCH) ? batchSize : 1) * ((delay ? numDelaySlots : 1)); // If there's only one, don't generate a loop if(numValues == 1) { @@ -41,10 +41,10 @@ void genVariableFill(EnvironmentExternalBase &env, const std::string &target, co } //-------------------------------------------------------------------------- void genScalarFill(EnvironmentExternalBase &env, const std::string &target, const std::string &value, - VarAccessDuplication varDuplication, unsigned int batchSize, bool delay = false, unsigned int numDelaySlots = 1) + unsigned int varAccess, unsigned int batchSize, bool delay = false, unsigned int numDelaySlots = 1) { // Determine number of values to fill in each thread - const unsigned int numValues = ((varDuplication == VarAccessDuplication::SHARED) ? 1 : batchSize) * ((delay ? 
numDelaySlots : 1)); + const unsigned int numValues = ((varAccess & VarAccessDim::BATCH) ? batchSize : 1) * ((delay ? numDelaySlots : 1)); // If there's only one, don't generate a loop if(numValues == 1) { @@ -86,11 +86,12 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e return backend.getDeviceVarPrefix() + var.name + A(g).getNameSuffix(); }); - // If variable is shared between neurons - if (getVarAccessDuplication(var.getAccess(VarAccess::READ_WRITE)) == VarAccessDuplication::SHARED_NEURON) { - backend.genPopVariableInit( - varEnv, - [&adaptor, &fieldGroup, &fieldSuffix, &group, &resolvedType, &var, &varInit, batchSize, numDelaySlots] + // If variable has NEURON axis + const unsigned int varAccess = var.getAccess(NeuronVarAccess::READ_WRITE); + if (varAccess & VarAccessDim::NEURON) { + backend.genVariableInit( + varEnv, count, "id", + [&adaptor, &fieldGroup, &fieldSuffix, &group, &var, &resolvedType, &varInit, batchSize, count, numDelaySlots] (EnvironmentExternalBase &env) { // Generate initial value into temporary variable @@ -101,17 +102,17 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e // Pretty print variable initialisation code Transpiler::ErrorHandler errorHandler("Group '" + group.getArchetype().getName() + "' variable '" + var.name + "' init code"); prettyPrintStatements(varInit.getCodeTokens(), group.getTypeContext(), varInitEnv, errorHandler); - + // Fill value across all delay slots and batches - genScalarFill(varInitEnv, "_value", "$(value)", getVarAccessDuplication(var.getAccess(VarAccess::READ_WRITE)), - batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); + genVariableFill(varInitEnv, "_value", "$(value)", "id", "$(" + count + ")", + varAccess, batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); }); } // Otherwise else { - backend.genVariableInit( - varEnv, count, "id", - [&adaptor, &fieldGroup, &fieldSuffix, &group, &var, &resolvedType, &varInit, batchSize, count, numDelaySlots] + backend.genPopVariableInit( + varEnv, + [&adaptor, &fieldGroup, &fieldSuffix, &group, &resolvedType, &var, &varInit, batchSize, numDelaySlots] (EnvironmentExternalBase &env) { // Generate initial value into temporary variable @@ -122,11 +123,10 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e // Pretty print variable initialisation code Transpiler::ErrorHandler errorHandler("Group '" + group.getArchetype().getName() + "' variable '" + var.name + "' init code"); prettyPrintStatements(varInit.getCodeTokens(), group.getTypeContext(), varInitEnv, errorHandler); - + // Fill value across all delay slots and batches - genVariableFill(varInitEnv, "_value", "$(value)", "id", "$(" + count + ")", - getVarAccessDuplication(var.getAccess(VarAccess::READ_WRITE)), - batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); + genScalarFill(varInitEnv, "_value", "$(value)", varAccess, + batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); }); } } @@ -185,7 +185,7 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, // Fill value across all batches genVariableFill(varInitEnv, "_value", "$(value)", "id_syn", stride, - getVarAccessDuplication(var.getAccess(VarAccess::READ_WRITE)), batchSize); + var.getAccess(SynapseVarAccess::READ_WRITE), batchSize); }); } } @@ -222,7 +222,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir [batchSize, this] (EnvironmentExternalBase &varEnv) { genVariableFill(varEnv, "_out_post", 
Type::writeNumeric(0.0, getScalarType()), - "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE, batchSize); + "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::NEURON, batchSize); }); @@ -235,7 +235,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir [batchSize, this](EnvironmentExternalBase &varEnv) { genVariableFill(varEnv, "_den_delay", Type::writeNumeric(0.0, getScalarType()), - "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE, + "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::NEURON, batchSize, true, getArchetype().getMaxDendriticDelayTimesteps()); }); @@ -269,7 +269,7 @@ void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend [batchSize, this] (EnvironmentExternalBase &varEnv) { genVariableFill(varEnv, "_out_pre", Type::writeNumeric(0.0, getScalarType()), - "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE, batchSize); + "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::NEURON, batchSize); }); } @@ -450,7 +450,8 @@ void NeuronInitGroupMerged::genInitSpikeCount(const BackendBase &backend, Enviro (getArchetype().isTrueSpikeRequired() && getArchetype().isDelayRequired()); // Zero across all delay slots and batches - genScalarFill(spikeCountEnv, "_spk_cnt", "0", VarAccessDuplication::DUPLICATE, batchSize, delayRequired, getArchetype().getNumDelaySlots()); + genScalarFill(spikeCountEnv, "_spk_cnt", "0", VarAccessDim::BATCH | VarAccessDim::NEURON, + batchSize, delayRequired, getArchetype().getNumDelaySlots()); }); } @@ -480,8 +481,8 @@ void NeuronInitGroupMerged::genInitSpikes(const BackendBase &backend, Environmen (getArchetype().isTrueSpikeRequired() && getArchetype().isDelayRequired()); // Zero across all delay slots and batches - genVariableFill(varEnv, "_spk", "0", "id", "$(num_neurons)", - VarAccessDuplication::DUPLICATE, batchSize, delayRequired, getArchetype().getNumDelaySlots()); + genVariableFill(varEnv, "_spk", "0", "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::NEURON, + batchSize, delayRequired, getArchetype().getNumDelaySlots()); }); } } @@ -499,7 +500,7 @@ void NeuronInitGroupMerged::genInitSpikeTime(const BackendBase &backend, Environ backend.genVariableInit(env, "num_neurons", "id", [batchSize, varName, this] (EnvironmentExternalBase &varEnv) { - genVariableFill(varEnv, varName, "-TIME_MAX", "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE, + genVariableFill(varEnv, varName, "-TIME_MAX", "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::NEURON, batchSize, getArchetype().isDelayRequired(), getArchetype().getNumDelaySlots()); }); } diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 83b9a70907..64e3162582 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -40,7 +40,7 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, &ng](const std::string&, VarAccessDuplication d) + [batchSize, &ng](const std::string&, unsigned int d) { return ng.getVarIndex(batchSize, d, "$(id)"); }); @@ -83,7 +83,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env [&backend](const auto &g, size_t) { return 
backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); // Read into local variable - const std::string idx = ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)"); + const std::string idx = ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)"); psmEnv.getStream() << "// postsynaptic model " << getIndex() << std::endl; psmEnv.printLine(getScalarType().getName() + " linSyn = $(_out_post)[" + idx + "];"); @@ -121,7 +121,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, &ng](const std::string&, VarAccessDuplication d) + [batchSize, &ng](const std::string&, unsigned int d) { return ng.getVarIndex(batchSize, d, "$(id)"); }); @@ -134,7 +134,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env prettyPrintStatements(getArchetype().getPSDecayCodeTokens(), getTypeContext(), varEnv, decayErrorHandler); // Write back linSyn - varEnv.printLine("$(_out_post)[" + ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)") + "] = linSyn;"); + varEnv.printLine("$(_out_post)[" + ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)") + "] = linSyn;"); } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &hash) const @@ -168,7 +168,7 @@ void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backe [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); // Add reverse insyn variable to - const std::string idx = ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)"); + const std::string idx = ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)"); outSynEnv.printLine(getArchetype().getPreTargetVar() + " += $(_out_pre)[" + idx + "];"); // Zero it again @@ -202,15 +202,15 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) + [batchSize, delayed, &synEnv, &ng](const std::string&, unsigned int d) { return ng.getReadVarIndex(delayed, batchSize, d, "$(id)"); }, - [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) + [batchSize, delayed, &synEnv, &ng](const std::string&, unsigned int d) { return ng.getWriteVarIndex(delayed, batchSize, d, "$(id)"); }, - [delayed](const std::string&, VarAccessDuplication) + [delayed](const std::string&, unsigned int) { return delayed; }); @@ -242,10 +242,10 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentEx // Loop through variables and copy between read and write delay slots // **YUCK** this a bit sketchy as fields may not have been added - could add fields here but need to guarantee uniqueness for(const auto &v : getArchetype().getWUModel()->getPostVars()) { - const unsigned int varAccess = v.getAccess(VarAccess::READ_WRITE); + const unsigned int varAccess = v.getAccess(NeuronVarAccess::READ_WRITE); if(varAccess & 
VarAccessMode::READ_WRITE) { - env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, getVarAccessDuplication(varAccess), "$(id)") + "] = "); - env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, getVarAccessDuplication(varAccess), "$(id)") + "];"); + env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, varAccess, "$(id)") + "] = "); + env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, varAccess, "$(id)") + "];"); } } } @@ -294,15 +294,15 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back const bool delayed = (getArchetype().getDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, delayed, &ng](const std::string&, VarAccessDuplication d) + [batchSize, delayed, &ng](const std::string&, unsigned int d) { return ng.getReadVarIndex(delayed, batchSize, d, "$(id)"); }, - [batchSize, delayed, &ng](const std::string&, VarAccessDuplication d) + [batchSize, delayed, &ng](const std::string&, unsigned int d) { return ng.getWriteVarIndex(delayed, batchSize, d, "$(id)"); }, - [delayed](const std::string&, VarAccessDuplication) + [delayed](const std::string&, unsigned int) { return delayed; }); @@ -334,10 +334,10 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(EnvironmentEx // Loop through variables and copy between read and write delay slots // **YUCK** this a bit sketchy as fields may not have been added - could add fields here but need to guarantee uniqueness for(const auto &v : getArchetype().getWUModel()->getPreVars()) { - const unsigned int varAccess = v.getAccess(VarAccess::READ_WRITE); + const unsigned int varAccess = v.getAccess(NeuronVarAccess::READ_WRITE); if(varAccess & VarAccessMode::READ_WRITE) { - env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, getVarAccessDuplication(varAccess), "$(id)") + "] = "); - env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, getVarAccessDuplication(varAccess), "$(id)") + "];"); + env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, varAccess, "$(id)") + "] = "); + env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, varAccess, "$(id)") + "];"); } } } @@ -496,7 +496,8 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Substitute spike times const std::string timePrecision = getTimeType().getName(); - const std::string spikeTimeReadIndex = getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "$(id)"); + const std::string spikeTimeReadIndex = getReadVarIndex(getArchetype().isDelayRequired(), batchSize, + VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)"); neuronEnv.add(getTimeType().addConst(), "st", "lsT", {neuronEnv.addInitialiser("const " + timePrecision + " lsT = $(_st)[" + spikeTimeReadIndex + "];")}); neuronEnv.add(getTimeType().addConst(), "prev_st", "lprevST", @@ -511,17 +512,17 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // **NOTE** always copy variables if variable is delayed EnvironmentLocalVarCache neuronVarEnv( *this, *this, getTypeContext(), neuronEnv, backend.getDeviceVarPrefix(), "", "l", - [batchSize, &neuronEnv, this](const std::string &varName, VarAccessDuplication d) + [batchSize, &neuronEnv, this](const std::string 
&varName, unsigned int d)
         {
             const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired());
             return getReadVarIndex(delayed, batchSize, d, "$(id)") ;
         },
-        [batchSize, &neuronEnv, this](const std::string &varName, VarAccessDuplication d)
+        [batchSize, &neuronEnv, this](const std::string &varName, unsigned int d)
         {
             const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired());
             return getWriteVarIndex(delayed, batchSize, d, "$(id)") ;
         },
-        [this](const std::string &varName, VarAccessDuplication)
+        [this](const std::string &varName, unsigned int)
         {
             return (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired());
         });
@@ -703,12 +704,12 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E

     // If spike times are required, copy times from register
     if(getArchetype().isSpikeTimeRequired()) {
-        neuronVarEnv.printLine("$(_st)[" + getWriteVarIndex(true, batchSize, VarAccessDuplication::DUPLICATE, "$(id)") + "] = $(st);");
+        neuronVarEnv.printLine("$(_st)[" + getWriteVarIndex(true, batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)") + "] = $(st);");
     }

     // If previous spike times are required, copy times from register
     if(getArchetype().isPrevSpikeTimeRequired()) {
-        neuronVarEnv.printLine("$(_prev_st)[" + getWriteVarIndex(true, batchSize, VarAccessDuplication::DUPLICATE, "$(id)") + "] = $(prev_st);");
+        neuronVarEnv.printLine("$(_prev_st)[" + getWriteVarIndex(true, batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)") + "] = $(prev_st);");
     }

     // Loop through outgoing synapse groups with some sort of presynaptic code
@@ -740,53 +741,56 @@ void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase &backend, En
     }
 }
 //--------------------------------------------------------------------------
-std::string NeuronUpdateGroupMerged::getVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const
+std::string NeuronUpdateGroupMerged::getVarIndex(unsigned int batchSize, unsigned int varAccess, const std::string &index) const
 {
     // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere?
-    if (varDuplication == VarAccessDuplication::SHARED_NEURON) {
-        return (batchSize == 1) ? "0" : "$(batch)";
+    const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize >= 1);
+    if (!(varAccess & VarAccessDim::NEURON)) {
+        return batched ? "$(batch)" : "0";
     }
-    else if(varDuplication == VarAccessDuplication::SHARED || batchSize == 1) {
-        return index;
+    else if(batched) {
+        return "$(_batch_offset) + " + index;
     }
     else {
-        return "$(_batch_offset) + " + index;
+        return index;
     }
 }
 //--------------------------------------------------------------------------
-std::string NeuronUpdateGroupMerged::getReadVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const
+std::string NeuronUpdateGroupMerged::getReadVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const
 {
     if(delay) {
-        if (varDuplication == VarAccessDuplication::SHARED_NEURON) {
-            return (batchSize == 1) ? "$(_read_delay_slot)" : "$(_read_batch_delay_slot)";
+        const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize >= 1);
+        if (!(varAccess & VarAccessDim::NEURON)) {
+            return batched ? 
"$(_read_batch_delay_slot)" : "$(_read_delay_slot)"; } - else if (varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { - return "$(_read_delay_offset) + " + index; + else if(batched) { + return "$(_read_batch_delay_offset) + " + index; } else { - return "$(_read_batch_delay_offset) + " + index; + return "$(_read_delay_offset) + " + index; } } else { - return getVarIndex(batchSize, varDuplication, index); + return getVarIndex(batchSize, varAccess, index); } } //-------------------------------------------------------------------------- -std::string NeuronUpdateGroupMerged::getWriteVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const +std::string NeuronUpdateGroupMerged::getWriteVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const { if(delay) { - if (varDuplication == VarAccessDuplication::SHARED_NEURON) { - return (batchSize == 1) ? "$(_write_delay_slot)" : "$(_write_batch_delay_slot)"; + const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize >= 1); + if (!(varAccess & VarAccessDim::NEURON)) { + return batched ? "$(_write_batch_delay_slot)" : "$(_write_delay_slot)"; } - else if (varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { - return "$(_write_delay_offset) + " + index; + else if (batched) { + return "$(_write_batch_delay_offset) + " + index; } else { - return "$(_write_batch_delay_offset) + " + index; + return "$(_write_delay_offset) + " + index; } } else { - return getVarIndex(batchSize, varDuplication, index); + return getVarIndex(batchSize, varAccess, index); } } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index 2d21e541f6..73de775c6b 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -96,7 +96,7 @@ void PreSpan::genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerg { CodeStream::Scope b(env.getStream()); - env.printLine("const unsigned int preInd = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "spike") + "];"); + env.printLine("const unsigned int preInd = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::PRE_NEURON, "spike") + "];"); const auto indexType = backend.getSynapseIndexType(sg); const auto indexTypeName = indexType.getName(); @@ -247,7 +247,7 @@ void PostSpan::genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMer { CodeStream::Scope b(env.getStream()); const std::string index = "(r * " + std::to_string(backend.getKernelBlockSize(KernelPresynapticUpdate)) + ") + " + backend.getThreadID(); - env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDuplication::DUPLICATE, index) + "];"); + env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::PRE_NEURON, index) + "];"); env.printLine("$(_sh_spk" + eventSuffix + ")[" + backend.getThreadID() + "] = spk;"); if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { env.printLine("$(_sh_row_length)[" + backend.getThreadID() + "] = $(_row_length)[spk];"); @@ -459,7 +459,7 @@ void PreSpanProcedural::genUpdate(EnvironmentExternalBase 
&env, PresynapticUpdat // Create environment and add presynaptic index EnvironmentGroupMergedField synEnv(groupEnv, sg); synEnv.add(Type::Uint32.addConst(), "id_pre", "preInd", - {synEnv.addInitialiser("const unsigned int preInd = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(_spike)") + "];")}); + {synEnv.addInitialiser("const unsigned int preInd = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::PRE_NEURON, "$(_spike)") + "];")}); // **YUCK** add a hidden copy of num_post so we can overwrite deeper in here without losing access to original synEnv.add(Type::Uint32.addConst(), "_num_post", "$(num_post)"); @@ -639,7 +639,7 @@ void PostSpanBitmask::genUpdate(EnvironmentExternalBase &env, PresynapticUpdateG { CodeStream::Scope b(env.getStream()); const std::string index = "(r * " + std::to_string(backend.getKernelBlockSize(KernelPresynapticUpdate)) + ") + " + backend.getThreadID(); - env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDuplication::DUPLICATE, index) + "];"); + env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::PRE_NEURON, index) + "];"); env.printLine("$(_sh_spk" + eventSuffix + ")[" + backend.getThreadID() + "] = spk;"); } backend.genSharedMemBarrier(env.getStream()); @@ -873,7 +873,7 @@ void PostSpanToeplitz::genUpdate(EnvironmentExternalBase &env, PresynapticUpdate { CodeStream::Scope b(env.getStream()); const std::string index = "(r * " + std::to_string(backend.getKernelBlockSize(KernelPresynapticUpdate)) + ") + " + backend.getThreadID(); - env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDuplication::DUPLICATE, index) + "];"); + env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::PRE_NEURON, index) + "];"); env.printLine("$(_sh_spk" + eventSuffix + ")[" + backend.getThreadID() + "] = spk;"); } backend.genSharedMemBarrier(env.getStream()); diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 264285061a..9bddae801c 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -32,13 +32,13 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa backend.getDeviceVarPrefix(), [&sg, batchSize](unsigned int a, const std::string&) { - return sg.getPreWUVarIndex(batchSize, getVarAccessDuplication(a), "$(id_pre)"); + return sg.getPreWUVarIndex(batchSize, a, "$(id_pre)"); }, "", true); synEnv.template addVars( backend.getDeviceVarPrefix(), [&sg, batchSize](unsigned int a, const std::string&) { - return sg.getPostWUVarIndex(batchSize, getVarAccessDuplication(a), "$(id_post)"); + return sg.getPostWUVarIndex(batchSize, a, "$(id_post)"); }, "", true); @@ -53,8 +53,8 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa const std::string timeStr = sg.getTimeType().getName(); const std::string axonalDelayMs = Type::writeNumeric(dt * (double)(sg.getArchetype().getDelaySteps() + 1u), sg.getTimeType()); const bool preDelay = sg.getArchetype().getSrcNeuronGroup()->isDelayRequired(); - const std::string preSTIndex = sg.getPreVarIndex(preDelay, batchSize, 
VarAccessDuplication::DUPLICATE, "$(id_pre)"); - const std::string prevPreSTIndex = sg.getPrePrevSpikeTimeIndex(preDelay, batchSize, VarAccessDuplication::DUPLICATE, "$(id_pre)"); + const std::string preSTIndex = sg.getPreVarIndex(preDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::PRE_NEURON, "$(id_pre)"); + const std::string prevPreSTIndex = sg.getPrePrevSpikeTimeIndex(preDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::PRE_NEURON, "$(id_pre)"); synEnv.add(sg.getTimeType().addConst(), "st_pre", "stPre", {synEnv.addInitialiser("const " + timeStr + " stPre = " + axonalDelayMs + " + $(_src_st)[" + preSTIndex + "];")}); synEnv.add(sg.getTimeType().addConst(), "prev_st_pre", "prevSTPre", @@ -67,8 +67,8 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa // Calculate backprop delay to add to (somatic) spike times and substitute in postsynaptic spike times const std::string backPropDelayMs = Type::writeNumeric(dt * (double)(sg.getArchetype().getBackPropDelaySteps() + 1u), sg.getTimeType()); const bool postDelay = sg.getArchetype().getTrgNeuronGroup()->isDelayRequired(); - const std::string postSTIndex = sg.getPostVarIndex(postDelay, batchSize, VarAccessDuplication::DUPLICATE, "$(id_post)"); - const std::string prevPostSTIndex = sg.getPostPrevSpikeTimeIndex(postDelay, batchSize, VarAccessDuplication::DUPLICATE, "$(id_post)"); + const std::string postSTIndex = sg.getPostVarIndex(postDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::POST_NEURON, "$(id_post)"); + const std::string prevPostSTIndex = sg.getPostPrevSpikeTimeIndex(postDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::POST_NEURON, "$(id_post)"); synEnv.add(sg.getTimeType().addConst(), "st_post", "stPost", {synEnv.addInitialiser("const " + timeStr + " stPost = " + backPropDelayMs + " + $(_trg_st)[" + postSTIndex + "];")}); synEnv.add(sg.getTimeType().addConst(), "prev_st_post", "prevSTPost", @@ -80,7 +80,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa backend.getDeviceVarPrefix(), [&sg, batchSize](unsigned int a, const std::string&) { - return sg.getSynVarIndex(batchSize, getVarAccessDuplication(a), "$(id_syn)"); + return sg.getSynVarIndex(batchSize, a, "$(id_syn)"); }); } // Otherwise, if weights are procedual @@ -123,7 +123,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa backend.getDeviceVarPrefix(), [&sg, batchSize](unsigned int a, const std::string&) { - return sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), "$(id_kernel)"); + return sg.getKernelVarIndex(batchSize, a, "$(id_kernel)"); }); } @@ -209,75 +209,76 @@ std::string SynapseGroupMergedBase::getPostDenDelayIndex(unsigned int batchSize, } } //---------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPreVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const +std::string SynapseGroupMergedBase::getPreVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const { - return getVarIndex(delay, batchSize, varDuplication, index, "pre"); + return getVarIndex(delay, batchSize, varAccess, VarAccessDim::PRE_NEURON, index, "pre"); } //-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPostVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const +std::string 
SynapseGroupMergedBase::getPostVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const { - return getVarIndex(delay, batchSize, varDuplication, index, "post"); + return getVarIndex(delay, batchSize, varAccess, VarAccessDim::POST_NEURON, index, "post"); } //-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPrePrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const +std::string SynapseGroupMergedBase::getPrePrevSpikeTimeIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const { - const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); - + const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize > 1); + if(delay) { - return (singleBatch ? "$(_pre_prev_spike_time_delay_offset) + " : "$(_pre_prev_spike_time_batch_delay_offset) + ") + index; + return (batched ? "$(_pre_prev_spike_time_batch_delay_offset) + " : "$(_pre_prev_spike_time_delay_offset) + " ) + index; } else { - return (singleBatch ? "" : "$(_pre_batch_offset) + ") + index; + return (batched ? "$(_pre_batch_offset) + " : "") + index; } } //-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPostPrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const +std::string SynapseGroupMergedBase::getPostPrevSpikeTimeIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const { - const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); + const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize > 1); if(delay) { - return (singleBatch ? "$(_post_prev_spike_time_delay_offset) + " : "$(_post_prev_spike_time_batch_delay_offset) + ") + index; + return (batched ? "$(_post_prev_spike_time_batch_delay_offset) + " : "$(_post_prev_spike_time_delay_offset) + ") + index; } else { - return (singleBatch ? "" : "$(_post_batch_offset) + ") + index; + return (batched ? "$(_post_batch_offset) + " : "") + index; } } //-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getSynVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const +std::string SynapseGroupMergedBase::getSynVarIndex(unsigned int batchSize, unsigned int varAccess, const std::string &index) const { - const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); - return (singleBatch ? "" : "$(_syn_batch_offset) + ") + index; + const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize > 1); + return (batched ? "$(_syn_batch_offset) + " : "") + index; } //-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getKernelVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const +std::string SynapseGroupMergedBase::getKernelVarIndex(unsigned int batchSize, unsigned int varAccess, const std::string &index) const { - const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); - return (singleBatch ? "" : "$(_kern_batch_offset) + ") + index; + const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize > 1); + return (batched ? 
"$(_kern_batch_offset) + " : "") + index; } //---------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, - const std::string &index, const std::string &prefix) const +std::string SynapseGroupMergedBase::getVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, + VarAccessDim neuronAxis, const std::string &index, const std::string &prefix) const { + const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize > 1); if (delay) { - if (varDuplication == VarAccessDuplication::SHARED_NEURON) { - return ((batchSize == 1) ? "$(_" + prefix + "_delay_slot)" : "$(_" + prefix + "_batch_delay_slot)"); + if (!(varAccess & neuronAxis)) { + return (batched ? "$(_" + prefix + "_batch_delay_slot)" : "$(_" + prefix + "_delay_slot)"); } - else if (varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { - return "$(_" + prefix + "_delay_offset) + " + index; + else if(batched) { + return "$(_" + prefix + "_batch_delay_offset) + " + index; } else { - return "$(_" + prefix + "_batch_delay_offset) + " + index; + return "$(_" + prefix + "_delay_offset) + " + index; } } else { - if (varDuplication == VarAccessDuplication::SHARED_NEURON) { - return (batchSize == 1) ? "0" : "$(batch)"; + if (!(varAccess & neuronAxis)) { + return batched ? "$(batch)" : "0"; } - else if (varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { - return index; + else if (batched) { + return "$(_" + prefix + "_batch_offset) + " + index; } else { - return "$(_" + prefix + "_batch_offset) + " + index; + return index; } } } diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index 0e41dd8010..e78da94859 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -176,11 +176,11 @@ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) // If model is batched we need to check all variable references // are SHARED as, connectivity itself is always SHARED if (batchSize > 1) { - // If any referenced presynaptic variables aren't shared, give error + // If any referenced presynaptic variables are batched, give error if (std::any_of(getPreVarReferences().cbegin(), getPreVarReferences().cend(), [](const auto &v) { - return (getVarAccessDuplication(v.second.getVar().getAccess(VarAccess::READ_WRITE)) != VarAccessDuplication::SHARED); + return (v.second.getVar().getAccess(NeuronVarAccess::READ_WRITE) & VarAccessDim::BATCH); })) { throw std::runtime_error("Presynaptic variables referenced by CustomConnectivityUpdate must be SHARED across batches"); @@ -190,7 +190,7 @@ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) if (std::any_of(getPostVarReferences().cbegin(), getPostVarReferences().cend(), [](const auto &v) { - return (getVarAccessDuplication(v.second.getVar().getAccess(VarAccess::READ_WRITE)) != VarAccessDuplication::SHARED); + return (v.second.getVar().getAccess(NeuronVarAccess::READ_WRITE)& VarAccessDim::BATCH); })) { throw std::runtime_error("Postsynaptic variables referenced by CustomConnectivityUpdate must be SHARED across batches"); diff --git a/src/genn/genn/customConnectivityUpdateModels.cc b/src/genn/genn/customConnectivityUpdateModels.cc index f2274039c2..ff1fa26425 100644 --- a/src/genn/genn/customConnectivityUpdateModels.cc +++ b/src/genn/genn/customConnectivityUpdateModels.cc @@ -75,7 +75,7 @@ void 
Base::validate(const std::unordered_map ¶mValues, if (std::any_of(vars.cbegin(), vars.cend(), [](const Models::Base::Var &v) { - return (v.getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED_NEURON); + return (v.getAccess(SynapseVarAccess::READ_WRITE) & VarAccessDim::SHARED_NEURON); })) { throw std::runtime_error("Custom connectivity update models cannot include variables with SHARED_NEURON access modes - they are only supported on pre, postsynaptic or neuron variables"); diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index 76902a785a..6b473f24d5 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -134,26 +134,26 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro // Check variable reference types Models::checkVarReferences(m_VarReferences, getCustomUpdateModel()->getVarRefs()); - // Update is per-neuron if any variables or variable reference targets AREN'T SHARED_NEURON + // Update is per-neuron if any variables or variable reference targets have neuron dimension const auto modelVars = getCustomUpdateModel()->getVars(); m_PerNeuron = std::any_of(m_VarReferences.cbegin(), m_VarReferences.cend(), [](const auto& v) { - return !(v.second.getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED_NEURON); + return (v.second.getVar().getAccess(NeuronVarAccess::READ_WRITE) & VarAccessDim::NEURON); }); m_PerNeuron |= std::any_of(modelVars.cbegin(), modelVars.cend(), [](const Models::Base::Var& v) { - return !(v.getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED_NEURON); + return (v.getAccess(NeuronVarAccess::READ_WRITE) & VarAccessDim::NEURON); }); // Loop through all variable references for(const auto &modelVarRef : getCustomUpdateModel()->getVarRefs()) { const auto &varRef = m_VarReferences.at(modelVarRef.name); - // If custom update is per-neuron, check that any variable references to SHARED_NEURON variables are read-only + // If custom update is per-neuron, check that any variable references to variables without NEURON axis are read-only // **NOTE** if custom update isn't per-neuron, it's totally fine to write to SHARED_NEURON variables - if(m_PerNeuron && (varRef.getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED_NEURON) + if(m_PerNeuron && !(varRef.getVar().getAccess(NeuronVarAccess::READ_WRITE) & VarAccessDim::NEURON) && (modelVarRef.access == VarAccessMode::READ_WRITE)) { throw std::runtime_error("Variable references to SHARED_NEURON variables in per-neuron custom updates cannot be read-write."); @@ -218,8 +218,8 @@ boost::uuids::detail::sha1::digest_type CustomUpdate::getHashDigest() const // Update hash with whether variable references require delay Utils::updateHash((v.second.getDelayNeuronGroup() == nullptr), hash); - // Update hash with duplication mode of target variable as this effects indexing code - Utils::updateHash(getVarAccessDuplication(v.second.getVar().getAccess(VarAccess::READ_WRITE)), hash); + // Update hash with target variable access mode as this effects indexing code + Utils::updateHash(v.second.getVar().getAccess(NeuronVarAccess::READ_WRITE), hash); } return hash.get_digest(); } @@ -340,8 +340,8 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getHashDigest() const // Update hash with whether variable references require transpose Utils::updateHash((v.second.getTransposeSynapseGroup() == nullptr), hash); - // Update hash with duplication mode of target variable as this effects indexing code - 
Utils::updateHash(getVarAccessDuplication(v.second.getVar().getAccess(VarAccess::READ_WRITE)), hash); + // Update hash with access mode of target variable as this effects indexing code + Utils::updateHash(v.second.getVar().getAccess(SynapseVarAccess::READ_WRITE), hash); } return hash.get_digest(); diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 3790ae7b8f..fe8cc9ec6d 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -75,10 +75,8 @@ std::string VarReference::getTargetName() const //---------------------------------------------------------------------------- bool VarReference::isDuplicated() const { - if(getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED) { - return false; - } - else { + // If target variable has BATCH dimension + if(getVar().getAccess(NeuronVarAccess::READ_WRITE) & VarAccessDim::BATCH) { return std::visit( Utils::Overload{ [](const CURef &ref) { return ref.group->isBatched(); }, @@ -87,6 +85,9 @@ bool VarReference::isDuplicated() const [](const auto&) { return true; }}, m_Detail); } + else { + return false; + } } //---------------------------------------------------------------------------- CustomUpdate *VarReference::getReferencedCustomUpdate() const @@ -175,10 +176,8 @@ std::string WUVarReference::getTargetName() const //---------------------------------------------------------------------------- bool WUVarReference::isDuplicated() const { - if(getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED) { - return false; - } - else { + // If target variable has BATCH dimension + if(getVar().getAccess(SynapseVarAccess::READ_WRITE) & VarAccessDim::BATCH) { return std::visit( Utils::Overload{ [](const CURef &ref) { return ref.group->isBatched(); }, @@ -186,6 +185,9 @@ bool WUVarReference::isDuplicated() const [](const WURef&) { return true; }}, m_Detail); } + else { + return false; + } } //---------------------------------------------------------------------------- SynapseGroup *WUVarReference::getSynapseGroup() const @@ -330,8 +332,8 @@ WUVarReference::WUVarReference(size_t varIndex, const Models::Base::VarVec &varV } // Check duplicatedness of variables - if((getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::DUPLICATE) - != (getTransposeVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::DUPLICATE)) + if((getVar().getAccess(SynapseVarAccess::READ_WRITE) & VarAccessDim::BATCH) + != (getTransposeVar().getAccess(SynapseVarAccess::READ_WRITE) & VarAccessDim::BATCH)) { throw std::runtime_error("Transpose updates can only be performed on similarly batched variables"); } From f875d9474b778ba8713383f369479931ecd7f01a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 18 Aug 2023 12:59:34 +0100 Subject: [PATCH 05/60] More iteration on this: * VarAccess is now a wrapper class around a std::variant holding NeuronVarAccess, SynapseVarAccess and std::monostate (for the defaults) * Mostly functional aside from custom updates --- include/genn/backends/cuda/backend.h | 4 +- .../backends/single_threaded_cpu/backend.h | 7 +- .../genn/genn/code_generator/backendBase.h | 14 +-- .../customConnectivityUpdateGroupMerged.h | 2 +- .../code_generator/customUpdateGroupMerged.h | 12 +-- .../genn/genn/code_generator/environment.h | 29 +++--- .../code_generator/neuronUpdateGroupMerged.h | 6 +- .../code_generator/synapseUpdateGroupMerged.h | 39 +++++--- include/genn/genn/customUpdate.h | 23 ++--- include/genn/genn/models.h | 38 +------- include/genn/genn/neuronModels.h | 6 +- 
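The wrapper class named in the message above is only introduced by the varAccess.h changes later in this patch, so purely as an illustration of the idea, here is a minimal sketch of what such a wrapper could look like. It assumes the bit layout from the first patch (the low five bits hold the VarAccessMode, the higher bits hold VarAccessDim flags) and assumes NeuronVarAccess::READ_WRITE as the fallback when no access type was specified; the names and that default are assumptions, not the actual GeNN implementation.

// Illustrative sketch only - requires <variant> plus the NeuronVarAccess,
// SynapseVarAccess and VarAccessDim enums defined earlier in this series
class VarAccess
{
public:
    VarAccess() = default;
    VarAccess(NeuronVarAccess a) : m_Access(a) {}
    VarAccess(SynapseVarAccess a) : m_Access(a) {}

    //! Extract the dimension flags stored above the low five mode bits
    VarAccessDim getDims() const
    {
        // Assumed default: an unspecified access behaves like a read-write neuron variable
        unsigned int bits = static_cast<unsigned int>(NeuronVarAccess::READ_WRITE);
        if(std::holds_alternative<NeuronVarAccess>(m_Access)) {
            bits = static_cast<unsigned int>(std::get<NeuronVarAccess>(m_Access));
        }
        else if(std::holds_alternative<SynapseVarAccess>(m_Access)) {
            bits = static_cast<unsigned int>(std::get<SynapseVarAccess>(m_Access));
        }
        return static_cast<VarAccessDim>(bits & ~0x1Fu);
    }

private:
    std::variant<std::monostate, NeuronVarAccess, SynapseVarAccess> m_Access;
};

Keeping std::monostate for "not specified" would let each call site keep choosing a context-appropriate default, which appears to be what the v.access.getDims() calls in the hunks below rely on.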
include/genn/genn/varAccess.h | 97 ++++++++++++++----- include/genn/genn/weightUpdateModels.h | 6 +- .../backends/single_threaded_cpu/backend.cc | 26 ++--- src/genn/genn/code_generator/backendBase.cc | 26 ++--- src/genn/genn/code_generator/backendSIMT.cc | 5 +- .../customConnectivityUpdateGroupMerged.cc | 4 +- .../code_generator/customUpdateGroupMerged.cc | 36 ++++--- .../genn/code_generator/generateRunner.cc | 49 +++++----- .../genn/code_generator/initGroupMerged.cc | 22 ++--- .../code_generator/neuronUpdateGroupMerged.cc | 81 ++++++++-------- .../synapseUpdateGroupMerged.cc | 54 +++++------ src/genn/genn/currentSourceModels.cc | 7 +- src/genn/genn/customConnectivityUpdate.cc | 4 +- .../genn/customConnectivityUpdateModels.cc | 31 +++--- src/genn/genn/customUpdate.cc | 22 ++--- src/genn/genn/models.cc | 10 +- src/genn/genn/neuronModels.cc | 4 +- src/genn/genn/postsynapticModels.cc | 5 +- src/genn/genn/weightUpdateModels.cc | 41 ++++---- 30 files changed, 357 insertions(+), 353 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index b74ad8e9d0..3552ee9f63 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -383,7 +383,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT const auto *cm = cg.getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { // If variable is reduction target - if(v.getAccessMode() & VarAccessModeAttribute::REDUCE) { + if(v.access & VarAccessModeAttribute::REDUCE) { // Add pointer field const auto resolvedType = v.type.resolve(cg.getTypeContext()); groupEnv.addField(resolvedType.createPointer(), "_" + v.name, v.name, @@ -394,7 +394,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT // Add NCCL reduction groupEnv.print("CHECK_NCCL_ERRORS(ncclAllReduce($(_" + v.name + "), $(_" + v.name + "), $(_size)"); - groupEnv.printLine(", " + getNCCLType(resolvedType) + ", " + getNCCLReductionType(v.getAccessMode()) + ", ncclCommunicator, 0));"); + groupEnv.printLine(", " + getNCCLType(resolvedType) + ", " + getNCCLReductionType(v.access) + ", ncclCommunicator, 0));"); } } diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 534e6ced0b..e79b961be7 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -246,16 +246,15 @@ class BACKEND_EXPORT Backend : public BackendBase /*! 
Because reduction operations are unnecessary in unbatched single-threaded CPU models so there's no need to actually reduce */ void genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateWUGroupMergedBase &cg, const std::string &idxName) const; - template + template void genWriteBackReductions(EnvironmentExternalBase &env, G &cg, const std::string &idxName, R getVarRefIndexFn) const { const auto *cm = cg.getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { // If variable is a reduction target, copy value from register straight back into global memory - const unsigned int varAccess = v.getAccess(VarAccess::READ_WRITE); - if(varAccess & VarAccessModeAttribute::REDUCE) { + if(v.access & VarAccessModeAttribute::REDUCE) { const std::string idx = env.getName(idxName); - env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(getVarAccessDim(varAccess), idx) << "] = " << env[v.name] << ";" << std::endl; + env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(v.access.getDims(), idx) << "] = " << env[v.name] << ";" << std::endl; } } diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index ccf9880b37..46f1c9fe04 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -560,20 +560,20 @@ class GENN_EXPORT BackendBase //-------------------------------------------------------------------------- // Private API //-------------------------------------------------------------------------- - template - std::vector genInitReductionTargets(CodeStream &os, const G &cg, const std::string &idx, R getVarRefIndexFn) const + template + std::vector genInitReductionTargets(CodeStream &os, const G &cg, const std::string &idx, + R getVarRefIndexFn) const { // Loop through variables std::vector reductionTargets; const auto *cm = cg.getArchetype().getCustomUpdateModel(); for (const auto &v : cm->getVars()) { // If variable is a reduction target, define variable initialised to correct initial value for reduction - const unsigned int varAccess = v.getAccess(VarAccess::READ_WRITE); - if (varAccess & VarAccessModeAttribute::REDUCE) { + if (v.access & VarAccessModeAttribute::REDUCE) { const auto resolvedType = v.type.resolve(cg.getTypeContext()); - os << resolvedType.getName() << " _lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(varAccess), resolvedType) << ";" << std::endl; - reductionTargets.push_back({v.name, resolvedType, getVarAccessMode(varAccess), - cg.getVarIndex(varAccess), idx)}); + os << resolvedType.getName() << " _lr" << v.name << " = " << getReductionInitialValue(v.access, resolvedType) << ";" << std::endl; + reductionTargets.push_back({v.name, resolvedType, v.access, + cg.getVarIndex(v.access.getDims(), idx)}); } } diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index 616bbd3e58..36a10362ee 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -129,7 +129,7 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public GroupMerged genPostamble); - std::string getVarIndex(unsigned int varAccess, const std::string &index) const; - std::string getVarRefIndex(bool delay, unsigned int varAccess, const std::string &index) const; + std::string getVarIndex(VarAccessDim varDims, const std::string &index) const; + 
std::string getVarRefIndex(bool delay, VarAccessDim varDims, const std::string &index) const; //---------------------------------------------------------------------------- // Static constants @@ -67,8 +67,8 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged genPostamble); - std::string getVarIndex(unsigned int varAccess, const std::string &index) const; - std::string getVarRefIndex(unsigned int varAccess, const std::string &index) const; + std::string getVarIndex(VarAccessDim varDims, const std::string &index) const; + std::string getVarRefIndex(VarAccessDim varDims, const std::string &index) const; }; @@ -141,7 +141,7 @@ class CustomUpdateHostReductionGroupMergedBase : public GroupMerged // Loop through variables and add pointers if they are reduction targets const auto *cm = this->getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { - if(v.getAccessMode() & VarAccessModeAttribute::REDUCE) { + if(v.access & VarAccessModeAttribute::REDUCE) { const auto fieldType = v.type.resolve(this->getTypeContext()).createPointer(); env.addField(fieldType, v.name, v.name, [&backend, v](const auto &g, size_t) @@ -153,7 +153,7 @@ class CustomUpdateHostReductionGroupMergedBase : public GroupMerged // Loop through variable references and add pointers if they are reduction targets for(const auto &v : cm->getVarRefs()) { - if(v.getAccessMode() & VarAccessModeAttribute::REDUCE) { + if(v.access & VarAccessModeAttribute::REDUCE) { const auto fieldType = v.type.resolve(this->getTypeContext()).createPointer(); env.addField(fieldType, v.name, v.name, [&backend, v](const auto &g, size_t) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 3ef7d5b24d..bc1bd7a581 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -405,7 +405,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &(GroupInternal::*)(void) const; - using GetVarIndexFn = std::function; + using GetVarIndexFn = std::function; template using GetVarRefIndexFn = std::function; @@ -639,14 +639,14 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getArchetype()); for(const auto &v : archetypeAdaptor.getDefs()) { const auto resolvedType = v.type.resolve(this->getGroup().getTypeContext()); - const auto qualifiedType = (readOnly || (v.getAccessMode() & VarAccessModeAttribute::READ_ONLY)) ? resolvedType.addConst() : resolvedType; + const auto qualifiedType = (readOnly || (v.access & VarAccessModeAttribute::READ_ONLY)) ? 
resolvedType.addConst() : resolvedType; addField(qualifiedType, v.name, resolvedType.createPointer(), v.name + fieldSuffix, [arrayPrefix, v](const auto &g, size_t) { return arrayPrefix + v.name + A(g).getNameSuffix(); }, - getIndexFn(v.getAccess(VarAccess::READ_WRITE), v.name)); + getIndexFn(v.access, v.name)); } } @@ -654,7 +654,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase(arrayPrefix, [&indexSuffix](unsigned int, const std::string &) { return indexSuffix; }, + addVars(arrayPrefix, [&indexSuffix](VarAccess, const std::string &) { return indexSuffix; }, fieldSuffix, readOnly); } @@ -704,8 +704,8 @@ class VarCachePolicy { public: using GroupInternal = typename G::GroupInternal; - using GetIndexFn = std::function; - using ShouldAlwaysCopyFn = std::function; + using GetIndexFn = std::function; + using ShouldAlwaysCopyFn = std::function; VarCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex, ShouldAlwaysCopyFn shouldAlwaysCopy = ShouldAlwaysCopyFn()) @@ -723,20 +723,17 @@ class VarCachePolicy //------------------------------------------------------------------------ bool shouldAlwaysCopy(G&, const Models::Base::Var &var) const { - // **TODO** default from InitModel class - return m_ShouldAlwaysCopy(var.name, getVarAccessDuplication(var.getAccess(VarAccess::READ_WRITE))); + return m_ShouldAlwaysCopy(var.name, var.access); } std::string getReadIndex(G&, const Models::Base::Var &var) const { - // **TODO** default from InitModel class - return m_GetReadIndex(var.name, getVarAccessDuplication(var.getAccess(VarAccess::READ_WRITE))); + return m_GetReadIndex(var.name, var.access); } std::string getWriteIndex(G&, const Models::Base::Var &var) const { - // **TODO** default from InitModel class - return m_GetWriteIndex(var.name, getVarAccessDuplication(var.getAccess(VarAccess::READ_WRITE))); + return m_GetWriteIndex(var.name, var.access); } std::string getTargetName(const GroupInternal &g, const Models::Base::Var &var) const @@ -860,7 +857,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P return arrayPrefix + this->getTargetName(group.getGroups().at(i), v); }); - if(v.getAccessMode() & VarAccessMode::READ_ONLY) { + if(v.access == VarAccessMode::READ_ONLY) { getContextStream() << "const "; } getContextStream() << resolvedType.getName() << " _" << m_LocalPrefix << v.name; @@ -868,7 +865,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P // If this isn't a reduction, read value from memory // **NOTE** by not initialising these variables for reductions, // compilers SHOULD emit a warning if user code doesn't set it to something - if(!(v.getAccessMode() & VarAccessModeAttribute::REDUCE)) { + if(!(v.access & VarAccessModeAttribute::REDUCE)) { getContextStream() << " = group->" << v.name << m_FieldSuffix << "[" << printSubs(this->getReadIndex(m_Group.get(), v), *this) << "]"; } getContextStream() << ";" << std::endl; @@ -880,7 +877,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P // Loop through referenced definitions again for(const auto &v : referencedDefs) { // If we should always copy variable or variable is read-write - if(this->shouldAlwaysCopy(m_Group.get(), v) || v.getAccessMode() & VarAccessMode::READ_WRITE) { + if(this->shouldAlwaysCopy(m_Group.get(), v) || (v.access == VarAccessMode::READ_WRITE)) { getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << printSubs(this->getWriteIndex(m_Group.get(), v), *this) << "]"; getContextStream() << " = _" << 
m_LocalPrefix << v.name << ";" << std::endl; } @@ -904,7 +901,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P // Resolve type, add qualifier if required and return const auto resolvedType = var->second.second.type.resolve(m_Context.get()); - const auto qualifiedType = (var->second.second.getAccessMode() & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; + const auto qualifiedType = (var->second.second.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; return {qualifiedType}; } } diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index a9bd8edd4a..275488c50a 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -185,9 +185,9 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase void generateWUVarUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize); - std::string getVarIndex(unsigned int batchSize, unsigned int varAccess, const std::string &index) const; - std::string getReadVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const; - std::string getWriteVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const; + std::string getVarIndex(unsigned int batchSize, VarAccessDim varDims, const std::string &index) const; + std::string getReadVarIndex(bool delay, unsigned int batchSize, VarAccessDim varDims, const std::string &index) const; + std::string getWriteVarIndex(bool delay, unsigned int batchSize, VarAccessDim varDims, const std::string &index) const; const std::vector &getMergedCurrentSourceGroups() const { return m_MergedCurrentSourceGroups; } const std::vector &getMergedInSynPSMGroups() const { return m_MergedInSynPSMGroups; } diff --git a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h index 1920602071..ed523bd9aa 100644 --- a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h @@ -44,33 +44,40 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMergedisDelayRequired(), batchSize, varAccess, index); + return getPreVarIndex(getArchetype().getSrcNeuronGroup()->isDelayRequired(), batchSize, varDims, index); } - std::string getPostVarIndex(unsigned int batchSize, unsigned int varAccess, const std::string &index) const + std::string getPostVarIndex(unsigned int batchSize, VarAccessDim varDims, const std::string &index) const { - return getPostVarIndex(getArchetype().getTrgNeuronGroup()->isDelayRequired(), batchSize, varAccess, index); + return getPostVarIndex(getArchetype().getTrgNeuronGroup()->isDelayRequired(), batchSize, varDims, index); } - std::string getPreWUVarIndex(unsigned int batchSize, unsigned int varAccess, const std::string &index) const + std::string getPreWUVarIndex(unsigned int batchSize, VarAccessDim varDims, const std::string &index) const { - return getPreVarIndex(getArchetype().getDelaySteps() != 0, batchSize, varAccess, index); + return getPreVarIndex(getArchetype().getDelaySteps() != 0, batchSize, varDims, index); } - std::string getPostWUVarIndex(unsigned int batchSize, unsigned int varAccess, const std::string &index) const + std::string getPostWUVarIndex(unsigned int batchSize, VarAccessDim varDims, const 
std::string &index) const { - return getPostVarIndex(getArchetype().getBackPropDelaySteps() != 0, batchSize, varAccess, index); + return getPostVarIndex(getArchetype().getBackPropDelaySteps() != 0, batchSize, varDims, index); } std::string getPostDenDelayIndex(unsigned int batchSize, const std::string &index, const std::string &offset) const; - std::string getPreVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const; - std::string getPostVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const; + std::string getPreVarIndex(bool delay, unsigned int batchSize, VarAccessDim varDims, const std::string &index) const + { + return getPrePostVarIndex(delay, batchSize, varDims, index, "pre"); + } + + std::string getPostVarIndex(bool delay, unsigned int batchSize, VarAccessDim varDims, const std::string &index) const + { + return getPrePostVarIndex(delay, batchSize, varDims, index, "post"); + } - std::string getPrePrevSpikeTimeIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const; - std::string getPostPrevSpikeTimeIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const; + std::string getPrePrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDim varDims, const std::string &index) const; + std::string getPostPrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDim varDims, const std::string &index) const; std::string getPostISynIndex(unsigned int batchSize, const std::string &index) const { @@ -82,8 +89,8 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged getUpdateCodeTokens() const{ return m_UpdateCodeTokens; } - template - bool isReduction(const std::unordered_map &varRefs, VarAccessDim reduceDim) const + template + bool isReduction(const std::unordered_map &varRefs, + VarAccessDim reduceDim) const { - // Return true if any variables have REDUCE flag in their access mode and doesn't have reduction dimension + // Return true if any variables have REDUCE flag in their access mode and don't have reduction dimension const auto vars = getCustomUpdateModel()->getVars(); if(std::any_of(vars.cbegin(), vars.cend(), [reduceDim](const Models::Base::Var &v) { - const unsigned int access = v.getAccess(VarAccess::READ_WRITE); - return (access & VarAccessModeAttribute::REDUCE) && !(access & reduceDim); + return ((v.access & VarAccessModeAttribute::REDUCE) + && !(v.access.getDims() & reduceDim)); })) { return true; @@ -106,7 +107,7 @@ class GENN_EXPORT CustomUpdateBase // and the variable it targets doesn't have reduction dimension const auto &varRef = varRefs.at(modelVarRef.name); if ((modelVarRef.access & VarAccessModeAttribute::REDUCE) - && !(varRef.getVar().getAccess(VarAccess::READ_WRITE) & reduceDim)) + && !(varRef.getVar().access.getDims() & reduceDim)) { return true; } @@ -116,7 +117,7 @@ class GENN_EXPORT CustomUpdateBase } //! 
Helper function to check if variable reference types match those specified in model - template + template void checkVarReferenceBatching(const std::unordered_map& varRefs, unsigned int batchSize) { // If target of any variable references is duplicated, custom update should be batched @@ -134,7 +135,7 @@ class GENN_EXPORT CustomUpdateBase // If custom update is batched, check that any variable references to variables that aren't batched are read-only // **NOTE** if custom update isn't batched, it's totally fine to write to shared variables - if(m_Batched && !(varRef.getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDim::BATCH) + if(m_Batched && !(varRef.getVar().access.getDims() & VarAccessDim::BATCH) && (modelVarRef.access == VarAccessMode::READ_WRITE)) { throw std::runtime_error("Variable references to non-batched variables in batched custom updates cannot be read-write."); @@ -265,8 +266,8 @@ class GENN_EXPORT CustomUpdate : public CustomUpdateBase //------------------------------------------------------------------------ // Protected const methods //------------------------------------------------------------------------ - bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDim::BATCH); } - bool isNeuronReduction() const { return isReduction(getVarReferences(), VarAccessDim::NEURON); } + bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDim::BATCH); } + bool isNeuronReduction() const { return isReduction(getVarReferences(), VarAccessDim::NEURON); } bool isPerNeuron() const{ return m_PerNeuron; } const NeuronGroup *getDelayNeuronGroup() const { return m_DelayNeuronGroup; } @@ -322,7 +323,7 @@ class GENN_EXPORT CustomUpdateWU : public CustomUpdateBase //------------------------------------------------------------------------ // Protected const methods //------------------------------------------------------------------------ - bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDim::BATCH); } + bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDim::BATCH); } bool isTransposeOperation() const; SynapseGroupInternal *getSynapseGroup() const { return m_SynapseGroup; } diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index cddb3db408..a002f27e6a 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -62,42 +62,20 @@ class GENN_EXPORT Base : public Snippet::Base {} Var(const std::string &n, const Type::ResolvedType &t, VarAccess a) - : name(n), type(t), access(static_cast(a)) + : name(n), type(t), access(a) {} Var(const std::string &n, const std::string &t, VarAccess a) - : name(n), type(t), access(static_cast(a)) + : name(n), type(t), access(a) {} - - /*Var(const std::string &n, const Type::ResolvedType &t, NeuronVarAccess a) - : name(n), type(t), access(static_cast(a)) - {} - Var(const std::string &n, const std::string &t, NeuronVarAccess a) - : name(n), type(t), access(static_cast(a)) - {}*/ bool operator == (const Var &other) const { return (std::tie(name, type, access) == std::tie(other.name, other.type, other.access)); } - unsigned int getAccess(VarAccess defaultAccess) const - { - return access.value_or(static_cast(defaultAccess)); - } - - VarAccessMode getAccessMode() const - { - if(access) { - return getVarAccessMode(access.value()); - } - else { - return VarAccessMode::READ_WRITE; - } - } - std::string name; Type::UnresolvedType type; - std::optional access; + VarAccess access; }; struct GENN_EXPORT VarRef @@ -112,11 +90,6 @@ 
class GENN_EXPORT Base : public Snippet::Base return (std::tie(name, type, access) == std::tie(other.name, other.type, other.access)); } - VarAccessMode getAccessMode() const - { - return access; - } - std::string name; Type::UnresolvedType type; VarAccessMode access; @@ -425,12 +398,11 @@ void checkVarReferences(const std::unordered_map &varRefs, const } // Check that no reduction targets reference duplicated variables - // **TODO** default from InitModel class - if((varRef.getVar().getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::DUPLICATE) + /*if((varRef.getVar().access.getDims() & VarAccessDuplication::DUPLICATE) && (modelVarRef.access & VarAccessModeAttribute::REDUCE)) { throw std::runtime_error("Reduction target variable reference must be to SHARED or SHARED_NEURON variables."); - } + }*/ } } } // GeNN::Models diff --git a/include/genn/genn/neuronModels.h b/include/genn/genn/neuronModels.h index 9e28a9a60a..a2bd9b6b52 100644 --- a/include/genn/genn/neuronModels.h +++ b/include/genn/genn/neuronModels.h @@ -197,8 +197,8 @@ class IzhikevichVariable : public Izhikevich SET_PARAM_NAMES({}); SET_VARS({{"V","scalar"}, {"U", "scalar"}, - {"a", "scalar", VarAccess::READ_ONLY}, {"b", "scalar", VarAccess::READ_ONLY}, - {"c", "scalar", VarAccess::READ_ONLY}, {"d", "scalar", VarAccess::READ_ONLY}}); + {"a", "scalar", NeuronVarAccess::READ_ONLY}, {"b", "scalar", NeuronVarAccess::READ_ONLY}, + {"c", "scalar", NeuronVarAccess::READ_ONLY}, {"d", "scalar", NeuronVarAccess::READ_ONLY}}); }; //---------------------------------------------------------------------------- @@ -282,7 +282,7 @@ class SpikeSourceArray : public Base "$(startSpike) != $(endSpike) && " "$(t) >= $(spikeTimes)[$(startSpike)]" ); SET_RESET_CODE( "$(startSpike)++;\n" ); - SET_VARS( {{"startSpike", "unsigned int"}, {"endSpike", "unsigned int", VarAccess::READ_ONLY_DUPLICATE}} ); + SET_VARS( {{"startSpike", "unsigned int"}, {"endSpike", "unsigned int", NeuronVarAccess::READ_ONLY_DUPLICATE}} ); SET_EXTRA_GLOBAL_PARAMS( {{"spikeTimes", "scalar*"}} ); SET_NEEDS_AUTO_REFRACTORY(false); }; diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h index 8c17395a83..5a3ffc50f1 100644 --- a/include/genn/genn/varAccess.h +++ b/include/genn/genn/varAccess.h @@ -1,5 +1,12 @@ #pragma once +// Standard C++ includes +#include + +// GeNN includes +#include "gennExport.h" +#include "gennUtils.h" + //---------------------------------------------------------------------------- // Enumerations //---------------------------------------------------------------------------- @@ -89,42 +96,86 @@ enum class SynapseVarAccess : unsigned int //---------------------------------------------------------------------------- // Operators //---------------------------------------------------------------------------- -inline bool operator & (unsigned int type, VarAccessMode mode) -{ - return (type & static_cast(mode)) != 0; -} - -inline bool operator & (unsigned int type, VarAccessDim dim) -{ - return (type & static_cast(dim)) != 0; -} - -inline bool operator & (unsigned int type, VarAccessModeAttribute modeAttribute) -{ - return (type & static_cast(modeAttribute)) != 0; -} - inline bool operator & (VarAccessMode mode, VarAccessModeAttribute modeAttribute) { return (static_cast(mode) & static_cast(modeAttribute)) != 0; } -inline bool operator & (VarAccessMode a, VarAccessMode b) +inline bool operator & (VarAccessDim a, VarAccessDim b) { return (static_cast(a) & static_cast(b)) != 0; } -inline unsigned int operator | (VarAccessDim a, VarAccessDim 
b) +inline VarAccessDim operator | (VarAccessDim a, VarAccessDim b) { - return (static_cast(a) | static_cast(b)); + return static_cast(static_cast(a) | static_cast(b)); } - //---------------------------------------------------------------------------- -// Helpers +// VarAccess //---------------------------------------------------------------------------- -inline VarAccessMode getVarAccessMode(unsigned int type) +//! Wrapper class encapsulating +GENN_EXPORT class VarAccess { - return static_cast(type & 0x1F); -} +public: + VarAccess() + {} + VarAccess(NeuronVarAccess n) : m_Access{n} + {} + VarAccess(SynapseVarAccess s) : m_Access{s} + {} + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + template + VarAccessDim getDims() const + { + const unsigned int val = std::visit( + Utils::Overload{ + [](std::monostate) { return static_cast(V::READ_WRITE); }, + [](V v) { return static_cast(v); }, + [](auto)->unsigned int { throw std::runtime_error("Invalid var access type"); }}, + m_Access); + + // Mask out dimension bits and cast to enum + return static_cast(val & ~0x1F); + } + + //! Returns true if this VarAccess would be valid for a neuron + bool isValidNeuron() const + { + return !std::holds_alternative(m_Access); + } + + //! Returns true if this VarAccess would be valid for a synapse + bool isValidSynapse() const + { + return !std::holds_alternative(m_Access); + } + + void updateHash(boost::uuids::detail::sha1 &hash) const + { + Utils::updateHash(m_Access, hash); + } + + //------------------------------------------------------------------------ + // Operators + //------------------------------------------------------------------------ + operator VarAccessMode() const + { + return std::visit( + Utils::Overload{ + [](std::monostate) { return VarAccessMode::READ_WRITE; }, + [](auto v) { return static_cast(static_cast(v) & 0x1F); }}, + m_Access); + } + +private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + std::variant m_Access; +}; + } // namespace GeNN diff --git a/include/genn/genn/weightUpdateModels.h b/include/genn/genn/weightUpdateModels.h index 36a0cc7c1e..8837011e67 100644 --- a/include/genn/genn/weightUpdateModels.h +++ b/include/genn/genn/weightUpdateModels.h @@ -131,7 +131,7 @@ class StaticPulse : public Base public: DECLARE_SNIPPET(StaticPulse); - SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}}); + SET_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}}); SET_SIM_CODE("addToPost(g);\n"); }; @@ -182,7 +182,7 @@ class StaticPulseDendriticDelay : public Base public: DECLARE_SNIPPET(StaticPulseDendriticDelay); - SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}, {"d", "uint8_t", VarAccess::READ_ONLY}}); + SET_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}, {"d", "uint8_t", SynapseVarAccess::READ_ONLY}}); SET_SIM_CODE("addToPostDelay(g, d);\n"); }; @@ -219,7 +219,7 @@ class StaticGraded : public Base DECLARE_SNIPPET(StaticGraded); SET_PARAM_NAMES({"Epre", "Vslope"}); - SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}}); + SET_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}}); SET_EVENT_CODE("addToPost(fmax(0.0, g * tanh((V_pre - Epre) / Vslope) * DT));\n"); diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 4d62c764d1..01b6331a22 100644 --- 
a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -2013,22 +2013,22 @@ void Backend::genEmitSpike(EnvironmentExternalBase &env, NeuronUpdateGroupMerged //-------------------------------------------------------------------------- void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateGroupMerged &cg, const std::string &idxName) const { - genWriteBackReductions(env, cg, idxName, - [&cg](const Models::VarReference &varRef, const std::string &index) - { - return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, - varRef.getVar().getAccess(NeuronVarAccess::READ_WRITE), - index); - }); + genWriteBackReductions( + env, cg, idxName, + [&cg](const Models::VarReference &varRef, const std::string &index) + { + return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, + varRef.getVar().access.getDims(), index); + }); } //-------------------------------------------------------------------------- void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateWUGroupMergedBase &cg, const std::string &idxName) const { - genWriteBackReductions(env, cg, idxName, - [&cg](const Models::WUVarReference &varRef, const std::string &index) - { - return cg.getVarRefIndex(varRef.getVar().getAccess(SynapseVarAccess::READ_WRITE), - index); - }); + genWriteBackReductions( + env, cg, idxName, + [&cg](const Models::WUVarReference &varRef, const std::string &index) + { + return cg.getVarRefIndex(varRef.getVar().access.getDims(), index); + }); } } // namespace GeNN::CodeGenerator::SingleThreadedCPU diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 7ef2780dc9..79eb63f4f4 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -699,22 +699,22 @@ std::string BackendBase::getReductionOperation(const std::string &reduction, con //----------------------------------------------------------------------- std::vector BackendBase::genInitReductionTargets(CodeStream &os, const CustomUpdateGroupMerged &cg, const std::string &idx) const { - return genInitReductionTargets(os, cg, idx, - [&cg](const Models::VarReference &varRef, const std::string &index) - { - return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, - varRef.getVar().getAccess(NeuronVarAccess::READ_WRITE), - index); - }); + return genInitReductionTargets( + os, cg, idx, + [&cg](const Models::VarReference &varRef, const std::string &index) + { + return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, + varRef.getVar().access.getDims(), index); + }); } //----------------------------------------------------------------------- std::vector BackendBase::genInitReductionTargets(CodeStream &os, const CustomUpdateWUGroupMerged &cg, const std::string &idx) const { - return genInitReductionTargets(os, cg, idx, - [&cg](const Models::WUVarReference &varRef, const std::string &index) - { - return cg.getVarRefIndex(varRef.getVar().getAccess(SynapseVarAccess::READ_WRITE), - index); - }); + return genInitReductionTargets( + os, cg, idx, + [&cg](const Models::WUVarReference &varRef, const std::string &index) + { + return cg.getVarRefIndex(varRef.getVar().access.getDims(), index); + }); } } // namespace GeNN::CodeGenerator \ No newline at end of file diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 03c1cbe1de..fc52106e0b 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ 
b/src/genn/genn/code_generator/backendSIMT.cc @@ -587,7 +587,8 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM } // Copy spikes into block of $(_spk) - const std::string queueOffset = ng.getWriteVarIndex(ng.getArchetype().isDelayRequired(), batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, ""); + const std::string queueOffset = ng.getWriteVarIndex(ng.getArchetype().isDelayRequired(), batchSize, + VarAccessDim::BATCH | VarAccessDim::NEURON, ""); if(!Utils::areTokensEmpty(ng.getArchetype().getThresholdConditionCodeTokens())) { const std::string queueOffsetTrueSpk = ng.getWriteVarIndex(ng.getArchetype().isTrueSpikeRequired() && ng.getArchetype().isDelayRequired(), batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, ""); @@ -805,7 +806,7 @@ void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, Mode { CodeStream::Scope b(groupEnv.getStream()); const std::string index = "(r * " + std::to_string(getKernelBlockSize(KernelPostsynapticUpdate)) + ") + " + getThreadID(); - groupEnv.printLine("const unsigned int spk = $(_trg_spk)[" + sg.getPostVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::POST_NEURON, index) + "];"); + groupEnv.printLine("const unsigned int spk = $(_trg_spk)[" + sg.getPostVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, index) + "];"); groupEnv.getStream() << "shSpk[" << getThreadID() << "] = spk;" << std::endl; if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 1871a8679b..6ac40c94d6 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -37,11 +37,11 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t { boost::uuids::detail::sha1 hashA; Type::updateHash(a.getVar().type, hashA); - Utils::updateHash(a.getVar().getAccess(VarAccess::READ_WRITE), hashA); + Utils::updateHash(a.getVar().access.getDims(), hashA); boost::uuids::detail::sha1 hashB; Type::updateHash(b.getVar().type, hashB); - Utils::updateHash(b.getVar().getAccess(VarAccess::READ_WRITE), hashB); + Utils::updateHash(b.getVar().access.getDims(), hashB); return (hashA.get_digest() < hashB.get_digest()); }); diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index f26ea1d336..a14b1ffee0 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -60,9 +60,9 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", - [this, &cuEnv](const std::string&, unsigned int d) + [this, &cuEnv](const std::string&, VarAccess d) { - return getVarIndex(d, "$(id)"); + return getVarIndex(d.getDims(), "$(id)"); }); // Create an environment which caches variable references in local variables if they are accessed @@ -71,8 +71,7 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E [this, &varEnv](const std::string&, const Models::VarReference &v) { return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, - 
v.getVar().getAccess(NeuronVarAccess::READ_WRITE), - "$(id)"); + v.getVar().access.getDims(), "$(id)"); }); Transpiler::ErrorHandler errorHandler("Custom update '" + getArchetype().getName() + "' update code"); @@ -82,11 +81,11 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E genPostamble(varRefEnv, *this); } //---------------------------------------------------------------------------- -std::string CustomUpdateGroupMerged::getVarIndex(unsigned int varAccess, const std::string &index) const +std::string CustomUpdateGroupMerged::getVarIndex(VarAccessDim varDims, const std::string &index) const { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? - const bool batched = (varAccess & VarAccessDim::BATCH) && getArchetype().isBatched(); - if (!(varAccess & VarAccessDim::NEURON)) { + const bool batched = (varDims & VarAccessDim::BATCH) && getArchetype().isBatched(); + if (!(varDims & VarAccessDim::NEURON)) { return batched ? "$(batch)" : "0"; } else if (batched) { @@ -99,12 +98,12 @@ std::string CustomUpdateGroupMerged::getVarIndex(unsigned int varAccess, const s } } //---------------------------------------------------------------------------- -std::string CustomUpdateGroupMerged::getVarRefIndex(bool delay, unsigned int varAccess, const std::string &index) const +std::string CustomUpdateGroupMerged::getVarRefIndex(bool delay, VarAccessDim varDims, const std::string &index) const { // If delayed, variable is shared, the batch size is one or this custom update isn't batched, batch delay offset isn't required if(delay) { - const bool batched = (varAccess & VarAccessDim::BATCH) && getArchetype().isBatched(); - if (!(varAccess & VarAccessDim::NEURON)) { + const bool batched = (varDims & VarAccessDim::BATCH) && getArchetype().isBatched(); + if (!(varDims & VarAccessDim::NEURON)) { return batched ? 
"$(_batch_delay_slot)" : "$(_delay_slot)"; } else if (batched) { @@ -118,7 +117,7 @@ std::string CustomUpdateGroupMerged::getVarRefIndex(bool delay, unsigned int var } } else { - return getVarIndex(varAccess, index); + return getVarIndex(varDims, index); } } //---------------------------------------------------------------------------- @@ -188,9 +187,9 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", - [this, &cuEnv](const std::string&, unsigned int d) + [this, &cuEnv](const std::string&, VarAccess d) { - return getVarIndex(d, "$(id_syn)"); + return getVarIndex(d.getDims(), "$(id_syn)"); }); // Create an environment which caches variable references in local variables if they are accessed @@ -198,8 +197,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", [this, &varEnv](const std::string&, const Models::WUVarReference &v) { - return getVarRefIndex(v.getVar().getAccess(SynapseVarAccess::READ_WRITE), - "$(id_syn)"); + return getVarRefIndex(v.getVar().access.getDims(), "$(id_syn)"); }); Transpiler::ErrorHandler errorHandler("Custom update '" + getArchetype().getName() + "' update code"); @@ -209,16 +207,16 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back genPostamble(varRefEnv, *this); } //---------------------------------------------------------------------------- -std::string CustomUpdateWUGroupMergedBase::getVarIndex(unsigned int varAccess, const std::string &index) const +std::string CustomUpdateWUGroupMergedBase::getVarIndex(VarAccessDim varDims, const std::string &index) const { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? - return (((varAccess & VarAccessDim::BATCH) && getArchetype().isBatched()) ? "$(_batch_offset) + " : "") + index; + return (((varDims & VarAccessDim::BATCH) && getArchetype().isBatched()) ? "$(_batch_offset) + " : "") + index; } //---------------------------------------------------------------------------- -std::string CustomUpdateWUGroupMergedBase::getVarRefIndex(unsigned int varAccess, const std::string &index) const +std::string CustomUpdateWUGroupMergedBase::getVarRefIndex(VarAccessDim varDims, const std::string &index) const { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? - return (((varAccess & VarAccessDim::BATCH) && getArchetype().isBatched()) ? "$(_batch_offset) + " : "") + index; + return (((varDims & VarAccessDim::BATCH) && getArchetype().isBatched()) ? 
"$(_batch_offset) + " : "") + index; } // ---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 1013b8efe1..9c510a7063 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -27,20 +27,20 @@ using namespace GeNN::CodeGenerator; //-------------------------------------------------------------------------- namespace { -unsigned int getNumVarCopies(unsigned int varAccess, unsigned int batchSize, bool batched = true) +unsigned int getNumVarCopies(VarAccessDim varDims, unsigned int batchSize, bool batched = true) { - return ((varAccess & VarAccessDim::BATCH) && batched) ? batchSize : 1; + return ((varDims & VarAccessDim::BATCH) && batched) ? batchSize : 1; } //-------------------------------------------------------------------------- -unsigned int getNumVarElements(unsigned int varAccess, unsigned int numNeurons) +unsigned int getNumVarElements(VarAccessDim varDims, unsigned int numNeurons) { - return (varAccess & VarAccessDim::NEURON) ? numNeurons : 1; + return (varDims & VarAccessDim::NEURON) ? numNeurons : 1; } //-------------------------------------------------------------------------- -unsigned int getVarSize(unsigned int varAccess, unsigned int numElements, unsigned int batchSize, +unsigned int getVarSize(VarAccessDim varDims, unsigned int numElements, unsigned int batchSize, unsigned int delaySlots = 1, bool batched = true) { - return getNumVarCopies(varAccess, batchSize, batched) * getNumVarElements(varAccess, numElements) * delaySlots; + return getNumVarCopies(varDims, batchSize, batched) * getNumVarElements(varDims, numElements) * delaySlots; } //-------------------------------------------------------------------------- void genSpikeMacros(CodeStream &os, const NeuronGroupInternal &ng, bool trueSpike) @@ -1081,9 +1081,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, std::vector neuronStatePushPullFunctions; for(const auto &var : neuronModel->getVars()) { const auto &varInit = n.second.getVarInitialisers().at(var.name); - const unsigned int varAccess = var.getAccess(VarAccess::READ_WRITE); - const unsigned int numCopies = getNumVarCopies(varAccess, batchSize); - const unsigned int numElements = getNumVarElements(varAccess, n.second.getNumNeurons()); + const unsigned int numCopies = getNumVarCopies(var.access.getDims(), batchSize); + const unsigned int numElements = getNumVarElements(var.access.getDims(), n.second.getNumNeurons()); const size_t count = n.second.isVarQueueRequired(var.name) ? 
numCopies * numElements * n.second.getNumDelaySlots() : numCopies * numElements; const bool autoInitialized = !Utils::areTokensEmpty(varInit.getCodeTokens()); const auto resolvedType = var.type.resolve(modelMerged.getModel().getTypeContext()); @@ -1153,7 +1152,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerPushFunc, runnerPullFunc, *cs, mem, currentSourcePushPullFunctions, [batchSize, &n](const CurrentSourceInternal&, const Models::Base::Var &var) { - return getVarSize(var.getAccess(VarAccess::READ_WRITE), + return getVarSize(var.access.getDims(), n.second.getNumNeurons(), batchSize); }); genRunnerEGPs(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, @@ -1177,7 +1176,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerPushFunc, runnerPullFunc, model.getCustomUpdates(), mem, statePushPullFunctions, [batchSize](const CustomUpdateInternal &c, const Models::Base::Var &var) { - return getVarSize(var.getAccess(VarAccess::READ_WRITE), + return getVarSize(var.access.getDims(), c.getSize(), batchSize, 1, c.isBatched()); }); @@ -1192,8 +1191,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const size_t count = ((sg->getMatrixType() & SynapseMatrixWeight::KERNEL) ? sg->getKernelSizeFlattened() : sg->getSrcNeuronGroup()->getNumNeurons() * backend.getSynapticMatrixRowStride(*sg)); - return getVarSize(var.getAccess(VarAccess::READ_WRITE), - count, batchSize, 1, c.isBatched()); + return getVarSize(var.access.getDims(), count, + batchSize, 1, c.isBatched()); }); allVarStreams << std::endl; @@ -1274,8 +1273,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, *sg, mem, [batchSize](const SynapseGroupInternal &sg, const Models::Base::Var &var) { - return getVarSize(var.getAccess(VarAccess::READ_WRITE), - sg.getTrgNeuronGroup()->getNumNeurons(), batchSize); + return getVarSize(var.access.getDims(), sg.getTrgNeuronGroup()->getNumNeurons(), + batchSize); }); } // Loop through fused outgoing synapse populations with weightupdate models that have presynaptic output @@ -1292,8 +1291,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, *sg, mem, [batchSize, preDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) { - return getVarSize(var.getAccess(VarAccess::READ_WRITE), - sg.getSrcNeuronGroup()->getNumNeurons(), + return getVarSize(var.access.getDims(), sg.getSrcNeuronGroup()->getNumNeurons(), batchSize, preDelaySlots); }); } @@ -1305,8 +1303,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, *sg, mem, [batchSize, postDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) { - return getVarSize(var.getAccess(VarAccess::READ_WRITE), - sg.getTrgNeuronGroup()->getNumNeurons(), + return getVarSize(var.access.getDims(), sg.getTrgNeuronGroup()->getNumNeurons(), batchSize, postDelaySlots); }); } @@ -1407,16 +1404,16 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const auto &varInit = s.second.getWUVarInitialisers().at(wuVar.name); const bool autoInitialized = !Utils::areTokensEmpty(varInit.getCodeTokens()); const auto resolvedType = 
wuVar.type.resolve(modelMerged.getModel().getTypeContext()); - const unsigned int wuVarAccess = wuVar.getAccess(VarAccess::READ_WRITE); + const unsigned int numCopies = getNumVarCopies(wuVar.access.getDims(), batchSize); if(individualWeights) { const size_t size = (size_t)s.second.getSrcNeuronGroup()->getNumNeurons() * (size_t)backend.getSynapticMatrixRowStride(s.second); genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, resolvedType, wuVar.name + s.second.getName(), s.second.getWUVarLocation(wuVar.name), - autoInitialized, size * getNumVarCopies(wuVarAccess, batchSize), mem, synapseGroupStatePushPullFunctions); + autoInitialized, size * numCopies, mem, synapseGroupStatePushPullFunctions); } else if(kernelWeights) { // Calculate size of kernel - const size_t size = s.second.getKernelSizeFlattened() * getNumVarCopies(wuVarAccess, batchSize); + const size_t size = s.second.getKernelSizeFlattened() * numCopies; // Generate variable genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, @@ -1451,8 +1448,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, s.second, synapseGroupStatePushPullFunctions, [batchSize](const SynapseGroupInternal &sg, const Models::Base::Var &var) { - return getVarSize(var.getAccess(VarAccess::READ_WRITE), - sg.getTrgNeuronGroup()->getNumNeurons(), batchSize); + return getVarSize(var.access.getDims(), sg.getTrgNeuronGroup()->getNumNeurons(), + batchSize); }); } @@ -1464,7 +1461,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, s.second, synapseGroupStatePushPullFunctions, [batchSize, preDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) { - return getVarSize(var.getAccess(VarAccess::READ_WRITE), sg.getSrcNeuronGroup()->getNumNeurons(), + return getVarSize(var.access.getDims(), sg.getSrcNeuronGroup()->getNumNeurons(), batchSize, preDelaySlots); }); @@ -1478,7 +1475,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, s.second, synapseGroupStatePushPullFunctions, [batchSize, postDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) { - return getVarSize(var.getAccess(VarAccess::READ_WRITE), sg.getTrgNeuronGroup()->getNumNeurons(), + return getVarSize(var.access.getDims(), sg.getTrgNeuronGroup()->getNumNeurons(), batchSize, postDelaySlots); }); } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index fecac43363..3f85c0256f 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -21,10 +21,10 @@ using namespace GeNN::Transpiler; namespace { void genVariableFill(EnvironmentExternalBase &env, const std::string &target, const std::string &value, const std::string &idx, const std::string &stride, - unsigned int varAccess, unsigned int batchSize, bool delay = false, unsigned int numDelaySlots = 1) + VarAccessDim varDims, unsigned int batchSize, bool delay = false, unsigned int numDelaySlots = 1) { // Determine number of values to fill in each thread - const unsigned int numValues = ((varAccess & VarAccessDim::BATCH) ? batchSize : 1) * ((delay ? numDelaySlots : 1)); + const unsigned int numValues = ((varDims & VarAccessDim::BATCH) ? batchSize : 1) * ((delay ? 
numDelaySlots : 1)); // If there's only one, don't generate a loop if(numValues == 1) { @@ -41,10 +41,10 @@ void genVariableFill(EnvironmentExternalBase &env, const std::string &target, co } //-------------------------------------------------------------------------- void genScalarFill(EnvironmentExternalBase &env, const std::string &target, const std::string &value, - unsigned int varAccess, unsigned int batchSize, bool delay = false, unsigned int numDelaySlots = 1) + VarAccessDim varDims, unsigned int batchSize, bool delay = false, unsigned int numDelaySlots = 1) { // Determine number of values to fill in each thread - const unsigned int numValues = ((varAccess & VarAccessDim::BATCH) ? batchSize : 1) * ((delay ? numDelaySlots : 1)); + const unsigned int numValues = ((varDims & VarAccessDim::BATCH) ? batchSize : 1) * ((delay ? numDelaySlots : 1)); // If there's only one, don't generate a loop if(numValues == 1) { @@ -87,11 +87,11 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e }); // If variable has NEURON axis - const unsigned int varAccess = var.getAccess(NeuronVarAccess::READ_WRITE); - if (varAccess & VarAccessDim::NEURON) { + const VarAccessDim varDims = var.access.getDims(); + if (varDims & VarAccessDim::NEURON) { backend.genVariableInit( varEnv, count, "id", - [&adaptor, &fieldGroup, &fieldSuffix, &group, &var, &resolvedType, &varInit, batchSize, count, numDelaySlots] + [&adaptor, &fieldGroup, &fieldSuffix, &group, &var, &resolvedType, &varInit, batchSize, count, numDelaySlots, varDims] (EnvironmentExternalBase &env) { // Generate initial value into temporary variable @@ -105,14 +105,14 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e // Fill value across all delay slots and batches genVariableFill(varInitEnv, "_value", "$(value)", "id", "$(" + count + ")", - varAccess, batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); + varDims, batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); }); } // Otherwise else { backend.genPopVariableInit( varEnv, - [&adaptor, &fieldGroup, &fieldSuffix, &group, &resolvedType, &var, &varInit, batchSize, numDelaySlots] + [&adaptor, &fieldGroup, &fieldSuffix, &group, &resolvedType, &var, &varInit, batchSize, numDelaySlots, varDims] (EnvironmentExternalBase &env) { // Generate initial value into temporary variable @@ -125,7 +125,7 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e prettyPrintStatements(varInit.getCodeTokens(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all delay slots and batches - genScalarFill(varInitEnv, "_value", "$(value)", varAccess, + genScalarFill(varInitEnv, "_value", "$(value)", varDims, batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); }); } @@ -185,7 +185,7 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, // Fill value across all batches genVariableFill(varInitEnv, "_value", "$(value)", "id_syn", stride, - var.getAccess(SynapseVarAccess::READ_WRITE), batchSize); + var.access.getDims(), batchSize); }); } } diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 64e3162582..a1a60f0692 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -40,9 +40,9 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend // Create an environment which caches variables in local 
variables if they are accessed EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, &ng](const std::string&, unsigned int d) + [batchSize, &ng](const std::string&, VarAccess d) { - return ng.getVarIndex(batchSize, d, "$(id)"); + return ng.getVarIndex(batchSize, d.getDims(), "$(id)"); }); // Pretty print code back to environment @@ -121,9 +121,9 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, &ng](const std::string&, unsigned int d) + [batchSize, &ng](const std::string&, VarAccess d) { - return ng.getVarIndex(batchSize, d, "$(id)"); + return ng.getVarIndex(batchSize, d.getDims(), "$(id)"); }); // Pretty print code back to environment @@ -202,15 +202,15 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, delayed, &synEnv, &ng](const std::string&, unsigned int d) + [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccess d) { - return ng.getReadVarIndex(delayed, batchSize, d, "$(id)"); + return ng.getReadVarIndex(delayed, batchSize, d.getDims(), "$(id)"); }, - [batchSize, delayed, &synEnv, &ng](const std::string&, unsigned int d) + [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccess d) { - return ng.getWriteVarIndex(delayed, batchSize, d, "$(id)"); + return ng.getWriteVarIndex(delayed, batchSize, d.getDims(), "$(id)"); }, - [delayed](const std::string&, unsigned int) + [delayed](const std::string&, VarAccess) { return delayed; }); @@ -242,10 +242,10 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentEx // Loop through variables and copy between read and write delay slots // **YUCK** this a bit sketchy as fields may not have been added - could add fields here but need to guarantee uniqueness for(const auto &v : getArchetype().getWUModel()->getPostVars()) { - const unsigned int varAccess = v.getAccess(NeuronVarAccess::READ_WRITE); - if(varAccess & VarAccessMode::READ_WRITE) { - env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, varAccess, "$(id)") + "] = "); - env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, varAccess, "$(id)") + "];"); + if(v.access == VarAccessMode::READ_WRITE) { + const VarAccessDim varDims = v.access.getDims(); + env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, varDims, "$(id)") + "] = "); + env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, varDims, "$(id)") + "];"); } } } @@ -294,15 +294,15 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back const bool delayed = (getArchetype().getDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, delayed, &ng](const std::string&, unsigned int d) + [batchSize, delayed, &ng](const std::string&, VarAccess d) { - return ng.getReadVarIndex(delayed, batchSize, d, "$(id)"); + return ng.getReadVarIndex(delayed, batchSize, d.getDims(), "$(id)"); }, - [batchSize, 
delayed, &ng](const std::string&, unsigned int d) + [batchSize, delayed, &ng](const std::string&, VarAccess d) { - return ng.getWriteVarIndex(delayed, batchSize, d, "$(id)"); + return ng.getWriteVarIndex(delayed, batchSize, d.getDims(), "$(id)"); }, - [delayed](const std::string&, unsigned int) + [delayed](const std::string&, VarAccess) { return delayed; }); @@ -334,10 +334,10 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(EnvironmentEx // Loop through variables and copy between read and write delay slots // **YUCK** this a bit sketchy as fields may not have been added - could add fields here but need to guarantee uniqueness for(const auto &v : getArchetype().getWUModel()->getPreVars()) { - const unsigned int varAccess = v.getAccess(NeuronVarAccess::READ_WRITE); - if(varAccess & VarAccessMode::READ_WRITE) { - env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, varAccess, "$(id)") + "] = "); - env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, varAccess, "$(id)") + "];"); + if(v.access == VarAccessMode::READ_WRITE) { + const VarAccessDim varDims = v.access.getDims(); + env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, varDims, "$(id)") + "] = "); + env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, varDims, "$(id)") + "];"); } } } @@ -512,17 +512,17 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // **NOTE** always copy variables if variable is delayed EnvironmentLocalVarCache neuronVarEnv( *this, *this, getTypeContext(), neuronEnv, backend.getDeviceVarPrefix(), "", "l", - [batchSize, &neuronEnv, this](const std::string &varName, unsigned int d) + [batchSize, &neuronEnv, this](const std::string &varName, VarAccess d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); - return getReadVarIndex(delayed, batchSize, d, "$(id)") ; + return getReadVarIndex(delayed, batchSize, d.getDims(), "$(id)") ; }, - [batchSize, &neuronEnv, this](const std::string &varName, unsigned int d) + [batchSize, &neuronEnv, this](const std::string &varName, VarAccess d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); - return getWriteVarIndex(delayed, batchSize, d, "$(id)") ; + return getWriteVarIndex(delayed, batchSize, d.getDims(), "$(id)") ; }, - [this](const std::string &varName, unsigned int) + [this](const std::string &varName, VarAccess) { return (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); }); @@ -709,7 +709,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // If previous spike times are required, copy times from register if(getArchetype().isPrevSpikeTimeRequired()) { - neuronVarEnv.printLine("$(_prev_st)[" + getWriteVarIndex(true, batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON "$(id)") + "] = $(prev_st);"); + neuronVarEnv.printLine("$(_prev_st)[" + getWriteVarIndex(true, batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)") + "] = $(prev_st);"); } // Loop through outgoing synapse groups with some sort of presynaptic code @@ -741,11 +741,12 @@ void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase &backend, En } } //-------------------------------------------------------------------------- -std::string NeuronUpdateGroupMerged::getVarIndex(unsigned int batchSize, unsigned int varAccess, const std::string &index) 
const +std::string NeuronUpdateGroupMerged::getVarIndex(unsigned int batchSize, VarAccessDim varDims, + const std::string &index) const { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? - const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize >= 1); - if (!(varAccess & VarAccessDim::NEURON)) { + const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); + if (!(varDims & VarAccessDim::NEURON)) { return batched ? "$(batch)" : "0"; } else if(batched) { @@ -756,11 +757,12 @@ std::string NeuronUpdateGroupMerged::getVarIndex(unsigned int batchSize, unsigne } } //-------------------------------------------------------------------------- -std::string NeuronUpdateGroupMerged::getReadVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const +std::string NeuronUpdateGroupMerged::getReadVarIndex(bool delay, unsigned int batchSize, + VarAccessDim varDims, const std::string &index) const { if(delay) { - const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize >= 1); - if (!(varAccess & VarAccessDim::NEURON)) { + const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); + if (!(varDims & VarAccessDim::NEURON)) { return batched ? "$(_read_batch_delay_slot)" : "$(_read_delay_slot)"; } else if(batched) { @@ -771,15 +773,16 @@ std::string NeuronUpdateGroupMerged::getReadVarIndex(bool delay, unsigned int ba } } else { - return getVarIndex(batchSize, varAccess, index); + return getVarIndex(batchSize, varDims, index); } } //-------------------------------------------------------------------------- -std::string NeuronUpdateGroupMerged::getWriteVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const +std::string NeuronUpdateGroupMerged::getWriteVarIndex(bool delay, unsigned int batchSize, + VarAccessDim varDims, const std::string &index) const { if(delay) { - const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize >= 1); - if (!(varAccess & VarAccessDim::NEURON)) { + const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); + if (!(varDims & VarAccessDim::NEURON)) { return batched ? 
"$(_write_batch_delay_slot)" : "$(_write_delay_slot)"; } else if (batched) { @@ -790,7 +793,7 @@ std::string NeuronUpdateGroupMerged::getWriteVarIndex(bool delay, unsigned int b } } else { - return getVarIndex(batchSize, varAccess, index); + return getVarIndex(batchSize, varDims, index); } } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 9bddae801c..2974acbce4 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -30,15 +30,15 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa // Substitute names of pre and postsynaptic weight update variable synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](unsigned int a, const std::string&) + [&sg, batchSize](VarAccess a, const std::string&) { - return sg.getPreWUVarIndex(batchSize, a, "$(id_pre)"); + return sg.getPreWUVarIndex(batchSize, a.getDims(), "$(id_pre)"); }, "", true); synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](unsigned int a, const std::string&) + [&sg, batchSize](VarAccess a, const std::string&) { - return sg.getPostWUVarIndex(batchSize, a, "$(id_post)"); + return sg.getPostWUVarIndex(batchSize, a.getDims(), "$(id_post)"); }, "", true); @@ -78,9 +78,9 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa if (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::INDIVIDUAL) { synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](unsigned int a, const std::string&) + [&sg, batchSize](VarAccess a, const std::string&) { - return sg.getSynVarIndex(batchSize, a, "$(id_syn)"); + return sg.getSynVarIndex(batchSize, a.getDims(), "$(id_syn)"); }); } // Otherwise, if weights are procedual @@ -121,9 +121,9 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](unsigned int a, const std::string&) + [&sg, batchSize](VarAccess a, const std::string&) { - return sg.getKernelVarIndex(batchSize, a, "$(id_kernel)"); + return sg.getKernelVarIndex(batchSize, a.getDims(), "$(id_kernel)"); }); } @@ -179,6 +179,7 @@ bool SynapseGroupMergedBase::isToeplitzConnectivityInitDerivedParamHeterogeneous //---------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getPreSlot(unsigned int batchSize) const { + // **TODO** this is basically VarAccessDim::BATCH if(getArchetype().getSrcNeuronGroup()->isDelayRequired()) { return (batchSize == 1) ? "$(_pre_delay_slot)" : "$(_pre_batch_delay_slot)"; } @@ -189,6 +190,7 @@ std::string SynapseGroupMergedBase::getPreSlot(unsigned int batchSize) const //---------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getPostSlot(unsigned int batchSize) const { + // **TODO** this is basically VarAccessDim::BATCH if(getArchetype().getTrgNeuronGroup()->isDelayRequired()) { return (batchSize == 1) ? 
"$(_post_delay_slot)" : "$(_post_batch_delay_slot)"; } @@ -208,20 +210,10 @@ std::string SynapseGroupMergedBase::getPostDenDelayIndex(unsigned int batchSize, return "(((*$(_den_delay_ptr) + " + offset + ") % " + std::to_string(getArchetype().getMaxDendriticDelayTimesteps()) + ") * $(num_post)) + " + batchID; } } -//---------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPreVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const -{ - return getVarIndex(delay, batchSize, varAccess, VarAccessDim::PRE_NEURON, index, "pre"); -} -//-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPostVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const -{ - return getVarIndex(delay, batchSize, varAccess, VarAccessDim::POST_NEURON, index, "post"); -} //-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPrePrevSpikeTimeIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const +std::string SynapseGroupMergedBase::getPrePrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDim varDims, const std::string &index) const { - const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize > 1); + const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); if(delay) { return (batched ? "$(_pre_prev_spike_time_batch_delay_offset) + " : "$(_pre_prev_spike_time_delay_offset) + " ) + index; @@ -231,9 +223,9 @@ std::string SynapseGroupMergedBase::getPrePrevSpikeTimeIndex(bool delay, unsigne } } //-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPostPrevSpikeTimeIndex(bool delay, unsigned int batchSize, unsigned int varAccess, const std::string &index) const +std::string SynapseGroupMergedBase::getPostPrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDim varDims, const std::string &index) const { - const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize > 1); + const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); if(delay) { return (batched ? "$(_post_prev_spike_time_batch_delay_offset) + " : "$(_post_prev_spike_time_delay_offset) + ") + index; @@ -243,24 +235,24 @@ std::string SynapseGroupMergedBase::getPostPrevSpikeTimeIndex(bool delay, unsign } } //-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getSynVarIndex(unsigned int batchSize, unsigned int varAccess, const std::string &index) const +std::string SynapseGroupMergedBase::getSynVarIndex(unsigned int batchSize, VarAccessDim varDims, const std::string &index) const { - const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize > 1); + const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); return (batched ? 
"$(_syn_batch_offset) + " : "") + index; } //-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getKernelVarIndex(unsigned int batchSize, unsigned int varAccess, const std::string &index) const +std::string SynapseGroupMergedBase::getKernelVarIndex(unsigned int batchSize, VarAccessDim varDims, const std::string &index) const { - const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize > 1); + const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); return (batched ? "$(_kern_batch_offset) + " : "") + index; } //---------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getVarIndex(bool delay, unsigned int batchSize, unsigned int varAccess, - VarAccessDim neuronAxis, const std::string &index, const std::string &prefix) const +std::string SynapseGroupMergedBase::getPrePostVarIndex(bool delay, unsigned int batchSize, VarAccessDim varDims, + const std::string &index, const std::string &prefix) const { - const bool batched = ((varAccess & VarAccessDim::BATCH) && batchSize > 1); + const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); if (delay) { - if (!(varAccess & neuronAxis)) { + if (!(varDims & VarAccessDim::NEURON)) { return (batched ? "$(_" + prefix + "_batch_delay_slot)" : "$(_" + prefix + "_delay_slot)"); } else if(batched) { @@ -271,7 +263,7 @@ std::string SynapseGroupMergedBase::getVarIndex(bool delay, unsigned int batchSi } } else { - if (!(varAccess & neuronAxis)) { + if (!(varDims & VarAccessDim::NEURON)) { return batched ? "$(batch)" : "0"; } else if (batched) { diff --git a/src/genn/genn/currentSourceModels.cc b/src/genn/genn/currentSourceModels.cc index 61838f7238..219421a3fa 100644 --- a/src/genn/genn/currentSourceModels.cc +++ b/src/genn/genn/currentSourceModels.cc @@ -35,12 +35,9 @@ void Base::validate(const std::unordered_map ¶mValues, // If any variables have a reduction access mode, give an error const auto vars = getVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v) - { - return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); - })) + [](const Models::Base::Var &v){ return !v.access.isValidNeuron(); })) { - throw std::runtime_error("Current source models cannot include variables with REDUCE access modes - they are only supported by custom update models"); + throw std::runtime_error("Current source model variables much have NeuronVarAccess access type"); } } } // namespace GeNN::CurrentSourceModels \ No newline at end of file diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index e78da94859..12da9eb75c 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -180,7 +180,7 @@ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) if (std::any_of(getPreVarReferences().cbegin(), getPreVarReferences().cend(), [](const auto &v) { - return (v.second.getVar().getAccess(NeuronVarAccess::READ_WRITE) & VarAccessDim::BATCH); + return (v.second.getVar().access.getDims() & VarAccessDim::BATCH); })) { throw std::runtime_error("Presynaptic variables referenced by CustomConnectivityUpdate must be SHARED across batches"); @@ -190,7 +190,7 @@ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) if (std::any_of(getPostVarReferences().cbegin(), getPostVarReferences().cend(), [](const auto &v) { - return 
(v.second.getVar().getAccess(NeuronVarAccess::READ_WRITE)& VarAccessDim::BATCH); + return (v.second.getVar().access.getDims() & VarAccessDim::BATCH); })) { throw std::runtime_error("Postsynaptic variables referenced by CustomConnectivityUpdate must be SHARED across batches"); diff --git a/src/genn/genn/customConnectivityUpdateModels.cc b/src/genn/genn/customConnectivityUpdateModels.cc index ff1fa26425..c5425ee846 100644 --- a/src/genn/genn/customConnectivityUpdateModels.cc +++ b/src/genn/genn/customConnectivityUpdateModels.cc @@ -47,38 +47,31 @@ void Base::validate(const std::unordered_map ¶mValues, Utils::validateVecNames(getVarRefs(), "Synapse variable reference"); Utils::validateVecNames(getPreVarRefs(), "Presynaptic variable reference"); Utils::validateVecNames(getPostVarRefs(), "Postsynaptic variable reference"); - // Validate variable initialisers Utils::validateInitialisers(preVars, preVarValues, "presynaptic variable", description); Utils::validateInitialisers(postVars, postVarValues, "postsynaptic variable", description); + // Validate variable reference initialisers Utils::validateInitialisers(getVarRefs(), varRefTargets, "variable reference", description); Utils::validateInitialisers(getPreVarRefs(), preVarRefTargets, "presynaptic variable reference", description); Utils::validateInitialisers(getPostVarRefs(), postVarRefTargets, "postsynaptic variable reference", description); - - // If any variables have a reduction access mode, give an error - // **YUCK** copy-paste from WUM - could go in helper/Models::Base + // Check variables have suitable access types if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); }) - || std::any_of(preVars.cbegin(), preVars.cend(), - [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); }) - || std::any_of(postVars.cbegin(), postVars.cend(), - [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); })) + [](const Models::Base::Var &v){ return !v.access.isValidSynapse(); })) { - throw std::runtime_error("Custom connectivity update models cannot include variables with REDUCE access modes - they are only supported by custom update models"); + throw std::runtime_error("Custom connectivity update models variables much have SynapseVarAccess access type"); } - - // If any variables have shared neuron duplication mode, give an error - // **YUCK** copy-paste from WUM - could go in helper/Models::Base - if (std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v) - { - return (v.getAccess(SynapseVarAccess::READ_WRITE) & VarAccessDim::SHARED_NEURON); - })) + if(std::any_of(preVars.cbegin(), preVars.cend(), + [](const Models::Base::Var &v){ return !v.access.isValidNeuron(); })) + { + throw std::runtime_error("Custom connectivity update models presynaptic variables much have NeuronVarAccess access type"); + } + if(std::any_of(postVars.cbegin(), postVars.cend(), + [](const Models::Base::Var &v){ return !v.access.isValidNeuron(); })) { - throw std::runtime_error("Custom connectivity update models cannot include variables with SHARED_NEURON access modes - they are only supported on pre, postsynaptic or neuron variables"); + throw std::runtime_error("Custom connectivity update models postsynaptic variables much have NeuronVarAccess access type"); } } } // namespace GeNN::CustomConnectivityUpdateModels \ No newline at end of file diff --git a/src/genn/genn/customUpdate.cc 
b/src/genn/genn/customUpdate.cc index 6b473f24d5..b6b9d29fe9 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -139,12 +139,12 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro m_PerNeuron = std::any_of(m_VarReferences.cbegin(), m_VarReferences.cend(), [](const auto& v) { - return (v.second.getVar().getAccess(NeuronVarAccess::READ_WRITE) & VarAccessDim::NEURON); + return (v.second.getVar().access.getDims() & VarAccessDim::NEURON); }); m_PerNeuron |= std::any_of(modelVars.cbegin(), modelVars.cend(), [](const Models::Base::Var& v) { - return (v.getAccess(NeuronVarAccess::READ_WRITE) & VarAccessDim::NEURON); + return (v.access.getDims() & VarAccessDim::NEURON); }); // Loop through all variable references @@ -153,7 +153,7 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro // If custom update is per-neuron, check that any variable references to variables without NEURON axis are read-only // **NOTE** if custom update isn't per-neuron, it's totally fine to write to SHARED_NEURON variables - if(m_PerNeuron && !(varRef.getVar().getAccess(NeuronVarAccess::READ_WRITE) & VarAccessDim::NEURON) + if(m_PerNeuron && !(varRef.getVar().access.getDims() & VarAccessDim::NEURON) && (modelVarRef.access == VarAccessMode::READ_WRITE)) { throw std::runtime_error("Variable references to SHARED_NEURON variables in per-neuron custom updates cannot be read-write."); @@ -179,7 +179,7 @@ void CustomUpdate::finalise(double dt, unsigned int batchSize) CustomUpdateBase::finalise(dt); // Check variable reference batching - checkVarReferenceBatching(m_VarReferences, batchSize); + checkVarReferenceBatching(m_VarReferences, batchSize); // If any variable references have delays auto delayRef = std::find_if(m_VarReferences.cbegin(), m_VarReferences.cend(), @@ -218,8 +218,8 @@ boost::uuids::detail::sha1::digest_type CustomUpdate::getHashDigest() const // Update hash with whether variable references require delay Utils::updateHash((v.second.getDelayNeuronGroup() == nullptr), hash); - // Update hash with target variable access mode as this effects indexing code - Utils::updateHash(v.second.getVar().getAccess(NeuronVarAccess::READ_WRITE), hash); + // Update hash with target variable dimensions as this effects indexing code + Utils::updateHash(v.second.getVar().access.getDims(), hash); } return hash.get_digest(); } @@ -267,7 +267,7 @@ CustomUpdateWU::CustomUpdateWU(const std::string &name, const std::string &updat // Give error if custom update model includes any shared neuron variables // **NOTE** because there's no way to reference neuron variables with WUVarReferences, // this safely checks for attempts to do neuron reductions - const auto vars = getCustomUpdateModel()->getVars(); + /*const auto vars = getCustomUpdateModel()->getVars(); if (std::any_of(vars.cbegin(), vars.cend(), [](const Models::Base::Var &v) { @@ -275,7 +275,7 @@ CustomUpdateWU::CustomUpdateWU(const std::string &name, const std::string &updat })) { throw std::runtime_error("Custom weight updates cannot use models with SHARED_NEURON variables."); - } + }*/ // If this is a transpose operation if(isTransposeOperation()) { @@ -311,7 +311,7 @@ void CustomUpdateWU::finalise(double dt, unsigned int batchSize) CustomUpdateBase::finalise(dt); // Check variable reference types - checkVarReferenceBatching(m_VarReferences, batchSize); + checkVarReferenceBatching(m_VarReferences, batchSize); } //---------------------------------------------------------------------------- bool 
CustomUpdateWU::isTransposeOperation() const @@ -340,8 +340,8 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getHashDigest() const // Update hash with whether variable references require transpose Utils::updateHash((v.second.getTransposeSynapseGroup() == nullptr), hash); - // Update hash with access mode of target variable as this effects indexing code - Utils::updateHash(v.second.getVar().getAccess(SynapseVarAccess::READ_WRITE), hash); + // Update hash with access mode of target variable dimensions as this effects indexing code + Utils::updateHash(v.second.getVar().access.getDims(), hash); } return hash.get_digest(); diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index fe8cc9ec6d..1029b238c1 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -76,7 +76,7 @@ std::string VarReference::getTargetName() const bool VarReference::isDuplicated() const { // If target variable has BATCH dimension - if(getVar().getAccess(NeuronVarAccess::READ_WRITE) & VarAccessDim::BATCH) { + if(getVar().access.getDims() & VarAccessDim::BATCH) { return std::visit( Utils::Overload{ [](const CURef &ref) { return ref.group->isBatched(); }, @@ -177,7 +177,7 @@ std::string WUVarReference::getTargetName() const bool WUVarReference::isDuplicated() const { // If target variable has BATCH dimension - if(getVar().getAccess(SynapseVarAccess::READ_WRITE) & VarAccessDim::BATCH) { + if(getVar().access.getDims() & VarAccessDim::BATCH) { return std::visit( Utils::Overload{ [](const CURef &ref) { return ref.group->isBatched(); }, @@ -332,8 +332,8 @@ WUVarReference::WUVarReference(size_t varIndex, const Models::Base::VarVec &varV } // Check duplicatedness of variables - if((getVar().getAccess(SynapseVarAccess::READ_WRITE) & VarAccessDim::BATCH) - != (getTransposeVar().getAccess(SynapseVarAccess::READ_WRITE) & VarAccessDim::BATCH)) + if((getVar().access.getDims() & VarAccessDim::BATCH) + != (getTransposeVar().access.getDims() & VarAccessDim::BATCH)) { throw std::runtime_error("Transpose updates can only be performed on similarly batched variables"); } @@ -385,7 +385,7 @@ void updateHash(const Base::Var &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.name, hash); Type::updateHash(v.type, hash); - Utils::updateHash(v.access, hash); + v.access.updateHash(hash); } //---------------------------------------------------------------------------- void updateHash(const Base::VarRef &v, boost::uuids::detail::sha1 &hash) diff --git a/src/genn/genn/neuronModels.cc b/src/genn/genn/neuronModels.cc index da52358947..3205d6d924 100644 --- a/src/genn/genn/neuronModels.cc +++ b/src/genn/genn/neuronModels.cc @@ -50,9 +50,9 @@ void Base::validate(const std::unordered_map ¶mValues, // If any variables have a reduction access mode, give an error const auto vars = getVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); })) + [](const Models::Base::Var &v){ return !v.access.isValidNeuron(); })) { - throw std::runtime_error("Neuron models cannot include variables with REDUCE access modes - they are only supported by custom update models"); + throw std::runtime_error("Neuron model variables much have NeuronVarAccess access type"); } } } // namespace GeNN::NeuronModels diff --git a/src/genn/genn/postsynapticModels.cc b/src/genn/genn/postsynapticModels.cc index 2713a9b89b..bf92e3976d 100644 --- a/src/genn/genn/postsynapticModels.cc +++ b/src/genn/genn/postsynapticModels.cc @@ -33,11 +33,12 @@ void 
Base::validate(const std::unordered_map ¶mValues, // Superclass Models::Base::validate(paramValues, varValues, description); + // If any variables have a reduction access mode, give an error const auto vars = getVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); })) + [](const Models::Base::Var &v){ return !v.access.isValidNeuron(); })) { - throw std::runtime_error("Postsynaptic models cannot include variables with REDUCE access modes - they are only supported by custom update models"); + throw std::runtime_error("Postsynaptic model variables much have NeuronVarAccess access type"); } } } // namespace GeNN::PostsynapticModels \ No newline at end of file diff --git a/src/genn/genn/weightUpdateModels.cc b/src/genn/genn/weightUpdateModels.cc index d1234f78eb..7ecd28c33b 100644 --- a/src/genn/genn/weightUpdateModels.cc +++ b/src/genn/genn/weightUpdateModels.cc @@ -77,37 +77,32 @@ void Base::validate(const std::unordered_map ¶mValues, // Superclass Models::Base::validate(paramValues, varValues, description); + + const auto preVars = getPreVars(); + const auto postVars = getPostVars(); Utils::validateVecNames(getPreVars(), "Presynaptic variable"); Utils::validateVecNames(getPostVars(), "Presynaptic variable"); - // If any variables have a reduction access mode, give an error + // Validate variable initialisers + Utils::validateInitialisers(preVars, preVarValues, "presynaptic variable", description); + Utils::validateInitialisers(postVars, postVarValues, "postsynaptic variable", description); + + // Check variables have suitable access types const auto vars = getVars(); - const auto preVars = getPreVars(); - const auto postVars = getPostVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); }) - || std::any_of(preVars.cbegin(), preVars.cend(), - [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); }) - || std::any_of(postVars.cbegin(), postVars.cend(), - [](const Models::Base::Var &v){ return (v.getAccessMode() & VarAccessModeAttribute::REDUCE); })) + [](const Models::Base::Var &v){ return !v.access.isValidSynapse(); })) { - throw std::runtime_error("Weight update models cannot include variables with REDUCE access modes - they are only supported by custom update models"); + throw std::runtime_error("Weight update models variables much have SynapseVarAccess access type"); } - - // Validate variable reference initialisers - Utils::validateInitialisers(preVars, preVarValues, "presynaptic variable", description); - - // Validate variable reference initialisers - Utils::validateInitialisers(postVars, postVarValues, "postsynaptic variable", description); - - // If any variables have shared neuron duplication mode, give an error - if (std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v) - { - return (v.getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED_NEURON); - })) + if(std::any_of(preVars.cbegin(), preVars.cend(), + [](const Models::Base::Var &v){ return !v.access.isValidNeuron(); })) + { + throw std::runtime_error("Weight update models presynaptic variables much have NeuronVarAccess access type"); + } + if(std::any_of(postVars.cbegin(), postVars.cend(), + [](const Models::Base::Var &v){ return !v.access.isValidNeuron(); })) { - throw std::runtime_error("Weight update models cannot include variables with SHARED_NEURON access modes - they are only 
supported on pre, postsynaptic or neuron variables"); + throw std::runtime_error("Weight update models postsynaptic variables much have NeuronVarAccess access type"); } } } // namespace WeightUpdateModels \ No newline at end of file From cca56b7300666946481ce6a6424e7fdb0e14f116 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 18 Aug 2023 13:08:05 +0100 Subject: [PATCH 06/60] sprinkled template around to appease GCC --- include/genn/backends/single_threaded_cpu/backend.h | 2 +- include/genn/genn/code_generator/backendBase.h | 2 +- include/genn/genn/customUpdate.h | 6 +++--- .../code_generator/customConnectivityUpdateGroupMerged.cc | 6 +++--- src/genn/genn/code_generator/initGroupMerged.cc | 8 ++++---- src/genn/genn/customConnectivityUpdate.cc | 4 ++-- src/genn/genn/customUpdate.cc | 4 ++-- 7 files changed, 16 insertions(+), 16 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index e79b961be7..6134f54b41 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -254,7 +254,7 @@ class BACKEND_EXPORT Backend : public BackendBase // If variable is a reduction target, copy value from register straight back into global memory if(v.access & VarAccessModeAttribute::REDUCE) { const std::string idx = env.getName(idxName); - env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(v.access.getDims(), idx) << "] = " << env[v.name] << ";" << std::endl; + env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(v.access.template getDims(), idx) << "] = " << env[v.name] << ";" << std::endl; } } diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 46f1c9fe04..9ae0667197 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -573,7 +573,7 @@ class GENN_EXPORT BackendBase const auto resolvedType = v.type.resolve(cg.getTypeContext()); os << resolvedType.getName() << " _lr" << v.name << " = " << getReductionInitialValue(v.access, resolvedType) << ";" << std::endl; reductionTargets.push_back({v.name, resolvedType, v.access, - cg.getVarIndex(v.access.getDims(), idx)}); + cg.getVarIndex(v.access.template getDims(), idx)}); } } diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index 82b174ad4b..d7fb8b5e26 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -95,7 +95,7 @@ class GENN_EXPORT CustomUpdateBase [reduceDim](const Models::Base::Var &v) { return ((v.access & VarAccessModeAttribute::REDUCE) - && !(v.access.getDims() & reduceDim)); + && !(v.access.template getDims() & reduceDim)); })) { return true; @@ -107,7 +107,7 @@ class GENN_EXPORT CustomUpdateBase // and the variable it targets doesn't have reduction dimension const auto &varRef = varRefs.at(modelVarRef.name); if ((modelVarRef.access & VarAccessModeAttribute::REDUCE) - && !(varRef.getVar().access.getDims() & reduceDim)) + && !(varRef.getVar().access.template getDims() & reduceDim)) { return true; } @@ -135,7 +135,7 @@ class GENN_EXPORT CustomUpdateBase // If custom update is batched, check that any variable references to variables that aren't batched are read-only // **NOTE** if custom update isn't batched, it's totally fine to write to shared variables - if(m_Batched && !(varRef.getVar().access.getDims() & VarAccessDim::BATCH) + if(m_Batched && !(varRef.getVar().access.template getDims() & 
VarAccessDim::BATCH) && (modelVarRef.access == VarAccessMode::READ_WRITE)) { throw std::runtime_error("Variable references to non-batched variables in batched custom updates cannot be read-write."); diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 6ac40c94d6..a46953d5d5 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -37,11 +37,11 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t { boost::uuids::detail::sha1 hashA; Type::updateHash(a.getVar().type, hashA); - Utils::updateHash(a.getVar().access.getDims(), hashA); + Utils::updateHash(a.getVar().access.template getDims(), hashA); boost::uuids::detail::sha1 hashB; Type::updateHash(b.getVar().type, hashB); - Utils::updateHash(b.getVar().access.getDims(), hashB); + Utils::updateHash(b.getVar().access.template getDims(), hashB); return (hashA.get_digest() < hashB.get_digest()); }); @@ -527,4 +527,4 @@ bool CustomConnectivityHostUpdateGroupMerged::isParamHeterogeneous(const std::st bool CustomConnectivityHostUpdateGroupMerged::isDerivedParamHeterogeneous(const std::string &name) const { return isParamValueHeterogeneous(name, [](const CustomConnectivityUpdateInternal &cg) { return cg.getDerivedParams(); }); -} \ No newline at end of file +} diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 3f85c0256f..2a98bb218d 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -87,7 +87,7 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e }); // If variable has NEURON axis - const VarAccessDim varDims = var.access.getDims(); + const VarAccessDim varDims = var.access.template getDims(); if (varDims & VarAccessDim::NEURON) { backend.genVariableInit( varEnv, count, "id", @@ -173,19 +173,19 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, genSynapseVariableRowInitFn(varEnv, [&group, &resolvedType, &stride, &var, &varInit, batchSize] (EnvironmentExternalBase &env) - { + { // Generate initial value into temporary variable EnvironmentGroupMergedField varInitEnv(env, group); varInitEnv.getStream() << resolvedType.getName() << " initVal;" << std::endl; varInitEnv.add(resolvedType, "value", "initVal"); - + // Pretty print variable initialisation code Transpiler::ErrorHandler errorHandler("Variable '" + var.name + "' init code" + std::to_string(group.getIndex())); prettyPrintStatements(varInit.getCodeTokens(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all batches genVariableFill(varInitEnv, "_value", "$(value)", "id_syn", stride, - var.access.getDims(), batchSize); + var.access.template getDims(), batchSize); }); } } diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index 12da9eb75c..0b4e105d25 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -180,7 +180,7 @@ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) if (std::any_of(getPreVarReferences().cbegin(), getPreVarReferences().cend(), [](const auto &v) { - return (v.second.getVar().access.getDims() & VarAccessDim::BATCH); + return (v.second.getVar().access.template getDims() & VarAccessDim::BATCH); })) { throw 
std::runtime_error("Presynaptic variables referenced by CustomConnectivityUpdate must be SHARED across batches"); @@ -190,7 +190,7 @@ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) if (std::any_of(getPostVarReferences().cbegin(), getPostVarReferences().cend(), [](const auto &v) { - return (v.second.getVar().access.getDims() & VarAccessDim::BATCH); + return (v.second.getVar().access.template getDims() & VarAccessDim::BATCH); })) { throw std::runtime_error("Postsynaptic variables referenced by CustomConnectivityUpdate must be SHARED across batches"); diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index b6b9d29fe9..40762b9339 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -139,12 +139,12 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro m_PerNeuron = std::any_of(m_VarReferences.cbegin(), m_VarReferences.cend(), [](const auto& v) { - return (v.second.getVar().access.getDims() & VarAccessDim::NEURON); + return (v.second.getVar().access.template getDims() & VarAccessDim::NEURON); }); m_PerNeuron |= std::any_of(modelVars.cbegin(), modelVars.cend(), [](const Models::Base::Var& v) { - return (v.access.getDims() & VarAccessDim::NEURON); + return (v.access.template getDims() & VarAccessDim::NEURON); }); // Loop through all variable references From 0b89d8ed5a7c3a378cb51dde89767df096152d6f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 18 Aug 2023 13:34:22 +0100 Subject: [PATCH 07/60] added CustomUpdateVarAccess and more elegant isValid<> method to VarAccess --- include/genn/genn/varAccess.h | 56 +++++++++---------- src/genn/genn/currentSourceModels.cc | 2 +- .../genn/customConnectivityUpdateModels.cc | 6 +- src/genn/genn/neuronModels.cc | 2 +- src/genn/genn/postsynapticModels.cc | 2 +- src/genn/genn/weightUpdateModels.cc | 6 +- 6 files changed, 37 insertions(+), 37 deletions(-) diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h index 5a3ffc50f1..91aae98245 100644 --- a/include/genn/genn/varAccess.h +++ b/include/genn/genn/varAccess.h @@ -69,29 +69,28 @@ enum class SynapseVarAccess : unsigned int //READ_ONLY_POST_DUPLICATE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::POST_NEURON) | static_cast(VarAccessDim::BATCH), }; -/*enum class CustomUpdateVarAccess : unsigned int +//! Supported combinations of access mode and dimension for custom update variables +/*! 
The axes are defined 'subtractively' ie VarAccessDim::BATCH indicates that this axis should be removed */ +enum class CustomUpdateVarAccess : unsigned int { - // Variables with matching shape - READ_WRITE, - READ_ONLY, + // Variables with same shape as groups custom update is attached to + READ_WRITE = static_cast(VarAccessMode::READ_WRITE), + READ_ONLY = static_cast(VarAccessMode::READ_ONLY), - // Variables shared across batches - READ_WRITE_SHARED, - READ_ONLY_SHARED, + // Variables which will be shared across batches if custom update is batched + READ_WRITE_SHARED = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::BATCH), + READ_ONLY_SHARED = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::BATCH), - - READ_WRITE_PRE, + // Variables which will be shared across neurons if per-neuron + READ_WRITE_SHARED_NEURON = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::NEURON), + READ_ONLY_SHARED_NEURON = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::NEURON), // Reduction variables - REDUCE_BATCH_SUM, - REDUCE_BATCH_MAX, - REDUCE_NEURON_SUM, - REDUCE_NEURON_MAX, - REDUCE_PRE_NEURON_SUM, - REDUCE_PRE_NEURON_MAX, - REDUCE_POST_NEURON_SUM, - REDUCE_POST_NEURON_MAX, -}*/ + REDUCE_BATCH_SUM = static_cast(VarAccessMode::REDUCE_SUM) | static_cast(VarAccessDim::BATCH), + REDUCE_BATCH_MAX = static_cast(VarAccessMode::REDUCE_MAX) | static_cast(VarAccessDim::BATCH), + REDUCE_NEURON_SUM = static_cast(VarAccessMode::REDUCE_SUM) | static_cast(VarAccessDim::NEURON), + REDUCE_NEURON_MAX = static_cast(VarAccessMode::REDUCE_MAX) | static_cast(VarAccessDim::NEURON), +}; //---------------------------------------------------------------------------- // Operators @@ -124,6 +123,8 @@ GENN_EXPORT class VarAccess {} VarAccess(SynapseVarAccess s) : m_Access{s} {} + VarAccess(CustomUpdateVarAccess c) : m_Access{c} + {} //------------------------------------------------------------------------ // Public API @@ -142,16 +143,15 @@ GENN_EXPORT class VarAccess return static_cast(val & ~0x1F); } - //! Returns true if this VarAccess would be valid for a neuron - bool isValidNeuron() const - { - return !std::holds_alternative(m_Access); - } - - //! 
Returns true if this VarAccess would be valid for a synapse - bool isValidSynapse() const + template + bool isValid() const { - return !std::holds_alternative(m_Access); + return std::visit( + Utils::Overload{ + [](std::monostate) { return true; }, + [](V v) { return true; }, + [](auto) { return false; }}, + m_Access); } void updateHash(boost::uuids::detail::sha1 &hash) const @@ -175,7 +175,7 @@ GENN_EXPORT class VarAccess //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - std::variant m_Access; + std::variant m_Access; }; } // namespace GeNN diff --git a/src/genn/genn/currentSourceModels.cc b/src/genn/genn/currentSourceModels.cc index 219421a3fa..672f013249 100644 --- a/src/genn/genn/currentSourceModels.cc +++ b/src/genn/genn/currentSourceModels.cc @@ -35,7 +35,7 @@ void Base::validate(const std::unordered_map ¶mValues, // If any variables have a reduction access mode, give an error const auto vars = getVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return !v.access.isValidNeuron(); })) + [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { throw std::runtime_error("Current source model variables much have NeuronVarAccess access type"); } diff --git a/src/genn/genn/customConnectivityUpdateModels.cc b/src/genn/genn/customConnectivityUpdateModels.cc index c5425ee846..8469b9a699 100644 --- a/src/genn/genn/customConnectivityUpdateModels.cc +++ b/src/genn/genn/customConnectivityUpdateModels.cc @@ -59,17 +59,17 @@ void Base::validate(const std::unordered_map ¶mValues, // Check variables have suitable access types if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return !v.access.isValidSynapse(); })) + [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { throw std::runtime_error("Custom connectivity update models variables much have SynapseVarAccess access type"); } if(std::any_of(preVars.cbegin(), preVars.cend(), - [](const Models::Base::Var &v){ return !v.access.isValidNeuron(); })) + [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { throw std::runtime_error("Custom connectivity update models presynaptic variables much have NeuronVarAccess access type"); } if(std::any_of(postVars.cbegin(), postVars.cend(), - [](const Models::Base::Var &v){ return !v.access.isValidNeuron(); })) + [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { throw std::runtime_error("Custom connectivity update models postsynaptic variables much have NeuronVarAccess access type"); } diff --git a/src/genn/genn/neuronModels.cc b/src/genn/genn/neuronModels.cc index 3205d6d924..31ec988c2d 100644 --- a/src/genn/genn/neuronModels.cc +++ b/src/genn/genn/neuronModels.cc @@ -50,7 +50,7 @@ void Base::validate(const std::unordered_map ¶mValues, // If any variables have a reduction access mode, give an error const auto vars = getVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return !v.access.isValidNeuron(); })) + [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { throw std::runtime_error("Neuron model variables much have NeuronVarAccess access type"); } diff --git a/src/genn/genn/postsynapticModels.cc b/src/genn/genn/postsynapticModels.cc index bf92e3976d..3ba49e37b5 100644 --- a/src/genn/genn/postsynapticModels.cc +++ b/src/genn/genn/postsynapticModels.cc @@ -36,7 +36,7 @@ void Base::validate(const 
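[Illustrative sketch, not part of the patch series] The templated isValid<>() check above relies on std::variant visitation with an overload set: a monostate alternative means "default access", a value of the expected enum type is accepted, and any other alternative is rejected. The example below is a minimal stand-alone version using only standard C++17; the enums and the Overload helper are illustrative stand-ins:

#include <variant>
#include <cassert>

namespace example
{
enum class NeuronAccess  { READ_WRITE, READ_ONLY };
enum class SynapseAccess { READ_WRITE, READ_ONLY };

// Minimal stand-in for an overload-set helper
template<typename... Ts> struct Overload : Ts... { using Ts::operator()...; };
template<typename... Ts> Overload(Ts...) -> Overload<Ts...>;

using Access = std::variant<std::monostate, NeuronAccess, SynapseAccess>;

template<typename V>
bool isValid(const Access &access)
{
    return std::visit(
        Overload{
            [](std::monostate) { return true; },   // default access is always acceptable
            [](V) { return true; },                // the expected enum type is acceptable
            [](auto) { return false; }},           // anything else is rejected
        access);
}
}   // namespace example

int main()
{
    using namespace example;
    assert(isValid<NeuronAccess>(Access{}));                           // default access
    assert(isValid<NeuronAccess>(Access{NeuronAccess::READ_ONLY}));    // matching enum type
    assert(!isValid<NeuronAccess>(Access{SynapseAccess::READ_WRITE})); // wrong enum type
    return 0;
}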
std::unordered_map ¶mValues, // If any variables have a reduction access mode, give an error const auto vars = getVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return !v.access.isValidNeuron(); })) + [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { throw std::runtime_error("Postsynaptic model variables much have NeuronVarAccess access type"); } diff --git a/src/genn/genn/weightUpdateModels.cc b/src/genn/genn/weightUpdateModels.cc index 7ecd28c33b..cd8ec96393 100644 --- a/src/genn/genn/weightUpdateModels.cc +++ b/src/genn/genn/weightUpdateModels.cc @@ -90,17 +90,17 @@ void Base::validate(const std::unordered_map ¶mValues, // Check variables have suitable access types const auto vars = getVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return !v.access.isValidSynapse(); })) + [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { throw std::runtime_error("Weight update models variables much have SynapseVarAccess access type"); } if(std::any_of(preVars.cbegin(), preVars.cend(), - [](const Models::Base::Var &v){ return !v.access.isValidNeuron(); })) + [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { throw std::runtime_error("Weight update models presynaptic variables much have NeuronVarAccess access type"); } if(std::any_of(postVars.cbegin(), postVars.cend(), - [](const Models::Base::Var &v){ return !v.access.isValidNeuron(); })) + [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { throw std::runtime_error("Weight update models postsynaptic variables much have NeuronVarAccess access type"); } From b41377131c736159619e7c0356429d1fd45628b9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 18 Aug 2023 13:35:42 +0100 Subject: [PATCH 08/60] fixed typos --- src/genn/genn/currentSourceModels.cc | 2 +- src/genn/genn/customConnectivityUpdateModels.cc | 6 +++--- src/genn/genn/neuronModels.cc | 2 +- src/genn/genn/postsynapticModels.cc | 2 +- src/genn/genn/weightUpdateModels.cc | 6 +++--- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/genn/genn/currentSourceModels.cc b/src/genn/genn/currentSourceModels.cc index 672f013249..9bd9c3e75f 100644 --- a/src/genn/genn/currentSourceModels.cc +++ b/src/genn/genn/currentSourceModels.cc @@ -37,7 +37,7 @@ void Base::validate(const std::unordered_map ¶mValues, if(std::any_of(vars.cbegin(), vars.cend(), [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { - throw std::runtime_error("Current source model variables much have NeuronVarAccess access type"); + throw std::runtime_error("Current source model variables must have NeuronVarAccess access type"); } } } // namespace GeNN::CurrentSourceModels \ No newline at end of file diff --git a/src/genn/genn/customConnectivityUpdateModels.cc b/src/genn/genn/customConnectivityUpdateModels.cc index 8469b9a699..6715dc6ed1 100644 --- a/src/genn/genn/customConnectivityUpdateModels.cc +++ b/src/genn/genn/customConnectivityUpdateModels.cc @@ -61,17 +61,17 @@ void Base::validate(const std::unordered_map ¶mValues, if(std::any_of(vars.cbegin(), vars.cend(), [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { - throw std::runtime_error("Custom connectivity update models variables much have SynapseVarAccess access type"); + throw std::runtime_error("Custom connectivity update models variables must have SynapseVarAccess access type"); } if(std::any_of(preVars.cbegin(), preVars.cend(), [](const 
Models::Base::Var &v){ return !v.access.template isValid(); })) { - throw std::runtime_error("Custom connectivity update models presynaptic variables much have NeuronVarAccess access type"); + throw std::runtime_error("Custom connectivity update models presynaptic variables must have NeuronVarAccess access type"); } if(std::any_of(postVars.cbegin(), postVars.cend(), [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { - throw std::runtime_error("Custom connectivity update models postsynaptic variables much have NeuronVarAccess access type"); + throw std::runtime_error("Custom connectivity update models postsynaptic variables must have NeuronVarAccess access type"); } } } // namespace GeNN::CustomConnectivityUpdateModels \ No newline at end of file diff --git a/src/genn/genn/neuronModels.cc b/src/genn/genn/neuronModels.cc index 31ec988c2d..6fdd7c210a 100644 --- a/src/genn/genn/neuronModels.cc +++ b/src/genn/genn/neuronModels.cc @@ -52,7 +52,7 @@ void Base::validate(const std::unordered_map ¶mValues, if(std::any_of(vars.cbegin(), vars.cend(), [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { - throw std::runtime_error("Neuron model variables much have NeuronVarAccess access type"); + throw std::runtime_error("Neuron model variables must have NeuronVarAccess access type"); } } } // namespace GeNN::NeuronModels diff --git a/src/genn/genn/postsynapticModels.cc b/src/genn/genn/postsynapticModels.cc index 3ba49e37b5..65bcb36e37 100644 --- a/src/genn/genn/postsynapticModels.cc +++ b/src/genn/genn/postsynapticModels.cc @@ -38,7 +38,7 @@ void Base::validate(const std::unordered_map ¶mValues, if(std::any_of(vars.cbegin(), vars.cend(), [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { - throw std::runtime_error("Postsynaptic model variables much have NeuronVarAccess access type"); + throw std::runtime_error("Postsynaptic model variables must have NeuronVarAccess access type"); } } } // namespace GeNN::PostsynapticModels \ No newline at end of file diff --git a/src/genn/genn/weightUpdateModels.cc b/src/genn/genn/weightUpdateModels.cc index cd8ec96393..d47f33145b 100644 --- a/src/genn/genn/weightUpdateModels.cc +++ b/src/genn/genn/weightUpdateModels.cc @@ -92,17 +92,17 @@ void Base::validate(const std::unordered_map ¶mValues, if(std::any_of(vars.cbegin(), vars.cend(), [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { - throw std::runtime_error("Weight update models variables much have SynapseVarAccess access type"); + throw std::runtime_error("Weight update models variables must have SynapseVarAccess access type"); } if(std::any_of(preVars.cbegin(), preVars.cend(), [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { - throw std::runtime_error("Weight update models presynaptic variables much have NeuronVarAccess access type"); + throw std::runtime_error("Weight update models presynaptic variables must have NeuronVarAccess access type"); } if(std::any_of(postVars.cbegin(), postVars.cend(), [](const Models::Base::Var &v){ return !v.access.template isValid(); })) { - throw std::runtime_error("Weight update models postsynaptic variables much have NeuronVarAccess access type"); + throw std::runtime_error("Weight update models postsynaptic variables must have NeuronVarAccess access type"); } } } // namespace WeightUpdateModels \ No newline at end of file From 4b43b8ad07eb430121d4ac7d91de724fa735e574 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 18 Aug 2023 13:51:27 +0100 Subject: 
[PATCH 09/60] flipped test in CustomUpdateBase::isReduction to match new VarAccess type --- include/genn/genn/customUpdate.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index d7fb8b5e26..86a0a67bde 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -89,13 +89,13 @@ class GENN_EXPORT CustomUpdateBase bool isReduction(const std::unordered_map &varRefs, VarAccessDim reduceDim) const { - // Return true if any variables have REDUCE flag in their access mode and don't have reduction dimension + // Return true if any variables have REDUCE flag in their access mode and have reduction dimension const auto vars = getCustomUpdateModel()->getVars(); if(std::any_of(vars.cbegin(), vars.cend(), [reduceDim](const Models::Base::Var &v) { return ((v.access & VarAccessModeAttribute::REDUCE) - && !(v.access.template getDims() & reduceDim)); + && (v.access.template getDims() & reduceDim)); })) { return true; @@ -104,10 +104,10 @@ class GENN_EXPORT CustomUpdateBase // Loop through all variable references for(const auto &modelVarRef : getCustomUpdateModel()->getVarRefs()) { // If custom update model reduces into this variable reference - // and the variable it targets doesn't have reduction dimension + // and the variable it targets has reduction dimension const auto &varRef = varRefs.at(modelVarRef.name); if ((modelVarRef.access & VarAccessModeAttribute::REDUCE) - && !(varRef.getVar().access.template getDims() & reduceDim)) + && (varRef.getVar().access.template getDims() & reduceDim)) { return true; } From a4b872da58b822272712e4c064b6db074f59d562 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 18 Aug 2023 13:51:42 +0100 Subject: [PATCH 10/60] fixed typo --- src/genn/genn/neuronModels.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/neuronModels.cc b/src/genn/genn/neuronModels.cc index 6fdd7c210a..ac03382abb 100644 --- a/src/genn/genn/neuronModels.cc +++ b/src/genn/genn/neuronModels.cc @@ -47,7 +47,7 @@ void Base::validate(const std::unordered_map ¶mValues, Utils::validateVecNames(getAdditionalInputVars(), "Additional input variable"); - // If any variables have a reduction access mode, give an error + // If any variables have an invalid access mode, give an error const auto vars = getVars(); if(std::any_of(vars.cbegin(), vars.cend(), [](const Models::Base::Var &v){ return !v.access.template isValid(); })) From 7ba267496e6b000ed944eb7d883d85481fbf3fba Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 18 Aug 2023 13:52:05 +0100 Subject: [PATCH 11/60] correct test for custom update model validity --- src/genn/genn/customUpdateModels.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/genn/genn/customUpdateModels.cc b/src/genn/genn/customUpdateModels.cc index e2c05ea599..d0f0018a66 100644 --- a/src/genn/genn/customUpdateModels.cc +++ b/src/genn/genn/customUpdateModels.cc @@ -55,5 +55,13 @@ void Base::validate(const std::unordered_map ¶mValues, // Validate variable reference initialisers Utils::validateInitialisers(varRefs, varRefTargets, "Variable reference", description); Utils::validateVecNames(getExtraGlobalParamRefs(), "Extra global parameter reference"); + + // If any variables have an invalid access mode, give an error + const auto vars = getVars(); + if(std::any_of(vars.cbegin(), vars.cend(), + [](const Models::Base::Var &v){ return !v.access.template isValid(); })) + { + throw std::runtime_error("Custom update 
model variables must have CustomUpdateVarAccess access type"); + } } } // namespace GeNN::CustomUpdateModels \ No newline at end of file From 1a60ad87e56f96a21bf9d0fe4ddb44eb2d2b98a7 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 18 Aug 2023 13:52:28 +0100 Subject: [PATCH 12/60] Renamed ``Models::checkVarReferences`` to ``Models::checkVarReferenceTypes`` to reflect all it now does --- include/genn/genn/models.h | 9 +-------- src/genn/genn/customConnectivityUpdate.cc | 6 +++--- src/genn/genn/customUpdate.cc | 4 ++-- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index a002f27e6a..5db306be4b 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -385,7 +385,7 @@ GENN_EXPORT void updateHash(const EGPReference &v, boost::uuids::detail::sha1 &h //! Helper function to check if variable reference types match those specified in model template -void checkVarReferences(const std::unordered_map &varRefs, const Base::VarRefVec &modelVarRefs) +void checkVarReferenceTypes(const std::unordered_map &varRefs, const Base::VarRefVec &modelVarRefs) { // Loop through all variable references for(const auto &modelVarRef : modelVarRefs) { @@ -396,13 +396,6 @@ void checkVarReferences(const std::unordered_map &varRefs, const if(varRef.getVar().type != modelVarRef.type) { throw std::runtime_error("Incompatible type for variable reference '" + modelVarRef.name + "'"); } - - // Check that no reduction targets reference duplicated variables - /*if((varRef.getVar().access.getDims() & VarAccessDuplication::DUPLICATE) - && (modelVarRef.access & VarAccessModeAttribute::REDUCE)) - { - throw std::runtime_error("Reduction target variable reference must be to SHARED or SHARED_NEURON variables."); - }*/ } } } // GeNN::Models diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index 0b4e105d25..8e35a1ccad 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -124,9 +124,9 @@ CustomConnectivityUpdate::CustomConnectivityUpdate(const std::string &name, cons } // Check variable reference types - Models::checkVarReferences(m_VarReferences, getCustomConnectivityUpdateModel()->getVarRefs()); - Models::checkVarReferences(m_PreVarReferences, getCustomConnectivityUpdateModel()->getPreVarRefs()); - Models::checkVarReferences(m_PostVarReferences, getCustomConnectivityUpdateModel()->getPostVarRefs()); + Models::checkVarReferenceTypes(m_VarReferences, getCustomConnectivityUpdateModel()->getVarRefs()); + Models::checkVarReferenceTypes(m_PreVarReferences, getCustomConnectivityUpdateModel()->getPreVarRefs()); + Models::checkVarReferenceTypes(m_PostVarReferences, getCustomConnectivityUpdateModel()->getPostVarRefs()); // Give error if any WU var references aren't pointing to synapse group if (std::any_of(m_VarReferences.cbegin(), m_VarReferences.cend(), diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index 40762b9339..9da7b2a5be 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -132,7 +132,7 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro } // Check variable reference types - Models::checkVarReferences(m_VarReferences, getCustomUpdateModel()->getVarRefs()); + Models::checkVarReferenceTypes(m_VarReferences, getCustomUpdateModel()->getVarRefs()); // Update is per-neuron if any variables or variable reference targets have neuron dimension const auto 
modelVars = getCustomUpdateModel()->getVars(); @@ -251,7 +251,7 @@ CustomUpdateWU::CustomUpdateWU(const std::string &name, const std::string &updat } // Check variable reference types - Models::checkVarReferences(m_VarReferences, getCustomUpdateModel()->getVarRefs()); + Models::checkVarReferenceTypes(m_VarReferences, getCustomUpdateModel()->getVarRefs()); // Give error if references point to different synapse groups // **NOTE** this could be relaxed for dense From fa8712b9b247dd4b598648a2d99f69957eccb214 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 29 Aug 2023 10:15:29 +0100 Subject: [PATCH 13/60] fixed compiler errors --- .../backends/single_threaded_cpu/backend.h | 2 +- .../genn/genn/code_generator/backendBase.h | 24 +++---- .../customConnectivityUpdateGroupMerged.h | 2 +- .../code_generator/customUpdateGroupMerged.h | 12 ++-- include/genn/genn/customUpdate.h | 39 +++++------ include/genn/genn/customUpdateInternal.h | 5 +- include/genn/genn/models.h | 8 +-- include/genn/genn/varAccess.h | 13 +++- src/genn/backends/cuda/optimiser.cc | 6 +- .../backends/single_threaded_cpu/backend.cc | 34 +++++----- src/genn/genn/code_generator/backendBase.cc | 55 ++++++++-------- src/genn/genn/code_generator/backendSIMT.cc | 65 ++++++++++--------- .../customConnectivityUpdateGroupMerged.cc | 15 ++--- .../code_generator/customUpdateGroupMerged.cc | 39 +++++------ .../genn/code_generator/generateRunner.cc | 4 +- .../genn/code_generator/initGroupMerged.cc | 6 +- src/genn/genn/customConnectivityUpdate.cc | 4 +- src/genn/genn/customUpdate.cc | 34 +++------- src/genn/genn/models.cc | 44 +++++-------- 19 files changed, 199 insertions(+), 212 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 6134f54b41..e7d4bb585e 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -254,7 +254,7 @@ class BACKEND_EXPORT Backend : public BackendBase // If variable is a reduction target, copy value from register straight back into global memory if(v.access & VarAccessModeAttribute::REDUCE) { const std::string idx = env.getName(idxName); - env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(v.access.template getDims(), idx) << "] = " << env[v.name] << ";" << std::endl; + env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(1, v.access.template getDims(), idx) << "] = " << env[v.name] << ";" << std::endl; } } diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 9ae0667197..cd1f9aa390 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -495,16 +495,16 @@ class GENN_EXPORT BackendBase void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; - void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; - void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; - void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void 
buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; - void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; - void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; - void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; @@ -550,19 +550,21 @@ class GENN_EXPORT BackendBase //! Helper function to generate initialisation code for any reduction operations carried out be custom update group. //! Returns vector of ReductionTarget structs, providing all information to write back reduction results to memory - std::vector genInitReductionTargets(CodeStream &os, const CustomUpdateGroupMerged &cg, const std::string &idx = "") const; + std::vector genInitReductionTargets(CodeStream &os, const CustomUpdateGroupMerged &cg, + unsigned int batchSize, const std::string &idx = "") const; //! Helper function to generate initialisation code for any reduction operations carried out be custom weight update group. //! //! 
Returns vector of ReductionTarget structs, providing all information to write back reduction results to memory - std::vector genInitReductionTargets(CodeStream &os, const CustomUpdateWUGroupMerged &cg, const std::string &idx = "") const; + std::vector genInitReductionTargets(CodeStream &os, const CustomUpdateWUGroupMerged &cg, + unsigned int batchSize, const std::string &idx = "") const; private: //-------------------------------------------------------------------------- // Private API //-------------------------------------------------------------------------- template - std::vector genInitReductionTargets(CodeStream &os, const G &cg, const std::string &idx, - R getVarRefIndexFn) const + std::vector genInitReductionTargets(CodeStream &os, const G &cg, unsigned int batchSize, + const std::string &idx, R getVarRefIndexFn) const { // Loop through variables std::vector reductionTargets; @@ -573,7 +575,7 @@ class GENN_EXPORT BackendBase const auto resolvedType = v.type.resolve(cg.getTypeContext()); os << resolvedType.getName() << " _lr" << v.name << " = " << getReductionInitialValue(v.access, resolvedType) << ";" << std::endl; reductionTargets.push_back({v.name, resolvedType, v.access, - cg.getVarIndex(v.access.template getDims(), idx)}); + cg.getVarIndex(batchSize, v.access.template getDims(), idx)}); } } diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index 36a10362ee..2f2d96080b 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -94,7 +94,7 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public GroupMerged genPostamble); - std::string getVarIndex(VarAccessDim varDims, const std::string &index) const; - std::string getVarRefIndex(bool delay, VarAccessDim varDims, const std::string &index) const; + std::string getVarIndex(unsigned int batchSize, VarAccessDim varDims, const std::string &index) const; + std::string getVarRefIndex(bool delay, unsigned int batchSize, VarAccessDim varDims, const std::string &index) const; //---------------------------------------------------------------------------- // Static constants @@ -64,11 +64,11 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged genPostamble); - std::string getVarIndex(VarAccessDim varDims, const std::string &index) const; - std::string getVarRefIndex(VarAccessDim varDims, const std::string &index) const; + std::string getVarIndex(unsigned int batchSize, VarAccessDim varDims, const std::string &index) const; + std::string getVarRefIndex(unsigned int batchSize, VarAccessDim varDims, const std::string &index) const; }; diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index 86a0a67bde..903276e82d 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -70,8 +70,8 @@ class GENN_EXPORT CustomUpdateBase bool isZeroCopyEnabled() const; - //! Is this custom update batched i.e. run in parallel across model batches - bool isBatched() const { return m_Batched; } + //! Get dimensions of this custom update + VarAccessDim getDims() const{ return m_Dims; } //! Updates hash with custom update /*! NOTE: this can only be called after model is finalized */ @@ -118,27 +118,24 @@ class GENN_EXPORT CustomUpdateBase //! 
Helper function to check if variable reference types match those specified in model template - void checkVarReferenceBatching(const std::unordered_map& varRefs, unsigned int batchSize) + void checkVarReferenceDims(const std::unordered_map& varRefs, unsigned int batchSize) { - // If target of any variable references is duplicated, custom update should be batched - if(batchSize > 1) { - m_Batched = std::any_of(varRefs.cbegin(), varRefs.cend(), - [](const auto &v) { return v.second.isDuplicated(); }); - } - else { - m_Batched = false; + // Loop through variable references and or together their dimensions to get dimensionality of update + m_Dims = VarAccessDim{0}; + for(const auto &v : varRefs) { + m_Dims = m_Dims | v.second.getDims(); } // Loop through all variable references for(const auto &modelVarRef : getCustomUpdateModel()->getVarRefs()) { const auto varRef = varRefs.at(modelVarRef.name); - // If custom update is batched, check that any variable references to variables that aren't batched are read-only - // **NOTE** if custom update isn't batched, it's totally fine to write to shared variables - if(m_Batched && !(varRef.getVar().access.template getDims() & VarAccessDim::BATCH) + // If the shape of the references variable doesn't match the dimensionality + // of the custom update, check its access mode isn't read-write + if((m_Dims != varRef.getVar().access.template getDims()) && (modelVarRef.access == VarAccessMode::READ_WRITE)) { - throw std::runtime_error("Variable references to non-batched variables in batched custom updates cannot be read-write."); + throw std::runtime_error("Variable references to lower-dimensional variables cannot be read-write."); } } } @@ -165,11 +162,11 @@ class GENN_EXPORT CustomUpdateBase //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - const std::string m_Name; - const std::string m_UpdateGroupName; + std::string m_Name; + std::string m_UpdateGroupName; const CustomUpdateModels::Base *m_CustomUpdateModel; - const std::unordered_map m_Params; + std::unordered_map m_Params; std::unordered_map m_DerivedParams; std::unordered_map m_VarInitialisers; @@ -184,8 +181,8 @@ class GENN_EXPORT CustomUpdateBase //! Tokens produced by scanner from update code std::vector m_UpdateCodeTokens; - //! Is this custom update batched i.e. run in parallel across model batches - bool m_Batched; + //! Dimensions of this custom update + VarAccessDim m_Dims; }; //---------------------------------------------------------------------------- @@ -268,7 +265,6 @@ class GENN_EXPORT CustomUpdate : public CustomUpdateBase //------------------------------------------------------------------------ bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDim::BATCH); } bool isNeuronReduction() const { return isReduction(getVarReferences(), VarAccessDim::NEURON); } - bool isPerNeuron() const{ return m_PerNeuron; } const NeuronGroup *getDelayNeuronGroup() const { return m_DelayNeuronGroup; } @@ -293,9 +289,6 @@ class GENN_EXPORT CustomUpdate : public CustomUpdateBase const std::unordered_map m_VarReferences; const unsigned int m_Size; const NeuronGroup *m_DelayNeuronGroup; - - //! Is this custom update per-neuron i.e. 
run in parallel across all neurons - bool m_PerNeuron; }; //------------------------------------------------------------------------ diff --git a/include/genn/genn/customUpdateInternal.h b/include/genn/genn/customUpdateInternal.h index bf6d4e416d..e8d975e5d1 100644 --- a/include/genn/genn/customUpdateInternal.h +++ b/include/genn/genn/customUpdateInternal.h @@ -26,8 +26,7 @@ class CustomUpdateInternal : public CustomUpdate using CustomUpdateBase::getDerivedParams; using CustomUpdateBase::isInitRNGRequired; using CustomUpdateBase::isZeroCopyEnabled; - using CustomUpdateBase::isBatched; - using CustomUpdate::isPerNeuron; + using CustomUpdateBase::getDims; using CustomUpdateBase::getVarLocationHashDigest; using CustomUpdateBase::getUpdateCodeTokens; @@ -86,7 +85,7 @@ class CustomUpdateWUInternal : public CustomUpdateWU using CustomUpdateBase::getDerivedParams; using CustomUpdateBase::isInitRNGRequired; using CustomUpdateBase::isZeroCopyEnabled; - using CustomUpdateBase::isBatched; + using CustomUpdateBase::getDims; using CustomUpdateBase::isReduction; using CustomUpdateBase::getVarLocationHashDigest; using CustomUpdateBase::getUpdateCodeTokens; diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 5db306be4b..bae2c03d35 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -201,8 +201,8 @@ class GENN_EXPORT VarReference : public VarReferenceBase // **TODO** rename to getNameSuffix std::string getTargetName() const; - //! If model is batched, will the variable this is referencing be duplicated? - bool isDuplicated() const; + //! Get dimensions of variable being referenced + VarAccessDim getDims() const; //! If this reference points to another custom update, return pointer to it /*! This is used to detect circular dependencies */ @@ -270,8 +270,8 @@ class GENN_EXPORT WUVarReference : public VarReferenceBase // **TODO** rename to getNameSuffix std::string getTargetName() const; - //! If model is batched, will the variable this is referencing be duplicated? - bool isDuplicated() const; + //! Get dimensions of variable being referenced + VarAccessDim getDims() const; SynapseGroup *getSynapseGroup() const; diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h index 91aae98245..649b9dc2df 100644 --- a/include/genn/genn/varAccess.h +++ b/include/genn/genn/varAccess.h @@ -110,6 +110,12 @@ inline VarAccessDim operator | (VarAccessDim a, VarAccessDim b) return static_cast(static_cast(a) | static_cast(b)); } + +inline VarAccessDim clearDim(VarAccessDim a, VarAccessDim b) +{ + return static_cast(static_cast(a) & ~static_cast(b)); +} + //---------------------------------------------------------------------------- // VarAccess //---------------------------------------------------------------------------- @@ -132,10 +138,14 @@ GENN_EXPORT class VarAccess template VarAccessDim getDims() const { + // Extract value const unsigned int val = std::visit( Utils::Overload{ + // If access is set to default, use READ_WRITE mode of typed var access e.g. 
NeuronVarAcccess::READ_WRITE [](std::monostate) { return static_cast(V::READ_WRITE); }, + // Otherwise, if stored type matches template type, use value [](V v) { return static_cast(v); }, + // Otherwise, give error [](auto)->unsigned int { throw std::runtime_error("Invalid var access type"); }}, m_Access); @@ -149,7 +159,7 @@ GENN_EXPORT class VarAccess return std::visit( Utils::Overload{ [](std::monostate) { return true; }, - [](V v) { return true; }, + [](V) { return true; }, [](auto) { return false; }}, m_Access); } @@ -164,6 +174,7 @@ GENN_EXPORT class VarAccess //------------------------------------------------------------------------ operator VarAccessMode() const { + // If access is set to default, access mode is always read-write otherwise mask out and cast access mode, bits return std::visit( Utils::Overload{ [](std::monostate) { return VarAccessMode::READ_WRITE; }, diff --git a/src/genn/backends/cuda/optimiser.cc b/src/genn/backends/cuda/optimiser.cc index ba03facae3..c76dc0a216 100644 --- a/src/genn/backends/cuda/optimiser.cc +++ b/src/genn/backends/cuda/optimiser.cc @@ -162,7 +162,7 @@ void calcGroupSizes(const CUDA::Preferences &preferences, const ModelSpecInterna // Loop through custom updates, add size to vector of custom update groups and update group name to set for(const auto &c : model.getCustomUpdates()) { - const size_t numCopies = (c.second.isBatched() && !c.second.isBatchReduction()) ? model.getBatchSize() : 1; + const size_t numCopies = ((c.second.getDims() & VarAccessDim::BATCH) && !c.second.isBatchReduction()) ? model.getBatchSize() : 1; const size_t size = numCopies * (c.second.isNeuronReduction() ? 32 : c.second.getSize()); groupSizes[KernelCustomUpdate].push_back(size); @@ -176,7 +176,7 @@ void calcGroupSizes(const CUDA::Preferences &preferences, const ModelSpecInterna for(const auto &c : model.getCustomWUUpdates()) { const SynapseGroupInternal *sgInternal = static_cast(c.second.getSynapseGroup()); if(c.second.isTransposeOperation()) { - const size_t numCopies = c.second.isBatched() ? model.getBatchSize() : 1; + const size_t numCopies = (c.second.getDims() & VarAccessDim::BATCH) ? model.getBatchSize() : 1; const size_t size = numCopies * sgInternal->getSrcNeuronGroup()->getNumNeurons() * sgInternal->getTrgNeuronGroup()->getNumNeurons(); groupSizes[KernelCustomTransposeUpdate].push_back(size); customTransposeUpdateKernels.insert(c.second.getUpdateGroupName()); @@ -184,7 +184,7 @@ void calcGroupSizes(const CUDA::Preferences &preferences, const ModelSpecInterna else { customUpdateKernels.insert(c.second.getUpdateGroupName()); - const size_t numCopies = (c.second.isBatched() && !c.second.isBatchReduction()) ? model.getBatchSize() : 1; + const size_t numCopies = ((c.second.getDims() & VarAccessDim::BATCH) && !c.second.isBatchReduction()) ? 
model.getBatchSize() : 1; if(sgInternal->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { groupSizes[KernelCustomUpdate].push_back(numCopies * sgInternal->getSrcNeuronGroup()->getNumNeurons() * sgInternal->getMaxConnections()); } diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 01b6331a22..722d3ddcae 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -574,16 +574,16 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, c); buildSizeEnvironment(groupEnv); - buildStandardEnvironment(groupEnv); + buildStandardEnvironment(groupEnv, 1); if (c.getArchetype().isNeuronReduction()) { // Initialise reduction targets // **TODO** these should be provided with some sort of caching mechanism - const auto reductionTargets = genInitReductionTargets(groupEnv.getStream(), c); + const auto reductionTargets = genInitReductionTargets(groupEnv.getStream(), c, 1); // Loop through group members EnvironmentGroupMergedField memberEnv(groupEnv, c); - if (c.getArchetype().isPerNeuron()) { + if (c.getArchetype().getDims() & VarAccessDim::NEURON) { memberEnv.print("for(unsigned int i = 0; i < $(size); i++)"); memberEnv.add(Type::Uint32.addConst(), "id", "i"); } @@ -592,7 +592,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back } { CodeStream::Scope b(memberEnv.getStream()); - c.generateCustomUpdate(*this, memberEnv, + c.generateCustomUpdate(*this, memberEnv, 1, [&reductionTargets, this](auto &env, auto&) { // Loop through reduction targets and generate reduction @@ -611,7 +611,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back else { // Loop through group members EnvironmentGroupMergedField memberEnv(groupEnv, c); - if (c.getArchetype().isPerNeuron()) { + if (c.getArchetype().getDims() & VarAccessDim::NEURON) { memberEnv.print("for(unsigned int i = 0; i < $(size); i++)"); memberEnv.add(Type::Uint32.addConst(), "id", "i"); } @@ -622,7 +622,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back CodeStream::Scope b(memberEnv.getStream()); // Generate custom update - c.generateCustomUpdate(*this, memberEnv, + c.generateCustomUpdate(*this, memberEnv, 1, [this](auto &env, auto &c) { // Write back reductions @@ -651,7 +651,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, c); buildSizeEnvironment(groupEnv); - buildStandardEnvironment(groupEnv); + buildStandardEnvironment(groupEnv, 1); // **TODO** add fields const SynapseGroupInternal *sg = c.getArchetype().getSynapseGroup(); @@ -660,7 +660,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back [&c, this](EnvironmentExternalBase &env) { // Call custom update handler - c.generateCustomUpdate(*this, env, + c.generateCustomUpdate(*this, env, 1, [this](auto &env, CustomUpdateWUGroupMergedBase &c) { // Write back reductions @@ -710,7 +710,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back } // Generate custom update - c.generateCustomUpdate(*this, synEnv, + c.generateCustomUpdate(*this, synEnv, 1, [this](auto &env, auto &c) { // Write back reductions @@ -778,7 +778,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back // Create 
matching environment EnvironmentGroupMergedField groupEnv(funcEnv, c); buildSizeEnvironment(groupEnv); - buildStandardEnvironment(groupEnv); + buildStandardEnvironment(groupEnv, 1); // Add field for transpose field and get its name const std::string transposeVarName = c.addTransposeField(*this, groupEnv); @@ -804,7 +804,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back // Generate custom update c.generateCustomUpdate( - *this, synEnv, + *this, synEnv, 1, [&transposeVarName, this](auto &env, const auto&) { // Update transpose variable @@ -922,7 +922,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: funcEnv.getStream() << "const auto *group = &mergedCustomUpdateInitGroup" << c.getIndex() << "[g]; " << std::endl; EnvironmentGroupMergedField groupEnv(funcEnv, c); - buildStandardEnvironment(groupEnv); + buildStandardEnvironment(groupEnv, 1); c.generateInit(*this, groupEnv, 1); } }); @@ -984,7 +984,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: funcEnv.getStream() << "const auto *group = &mergedCustomWUUpdateInitGroup" << c.getIndex() << "[g]; " << std::endl; EnvironmentGroupMergedField groupEnv(funcEnv, c); - buildStandardEnvironment(groupEnv); + buildStandardEnvironment(groupEnv, 1); c.generateInit(*this, groupEnv, 1); } }); @@ -1210,7 +1210,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedCustomWUUpdateSparseInitGroup" << c.getIndex() << "[g]; " << std::endl; EnvironmentGroupMergedField groupEnv(funcEnv, c); - buildStandardEnvironment(groupEnv); + buildStandardEnvironment(groupEnv, 1); groupEnv.printLine("// Loop through presynaptic neurons"); groupEnv.print("for (unsigned int i = 0; i < $(num_pre); i++)"); @@ -2017,8 +2017,8 @@ void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateG env, cg, idxName, [&cg](const Models::VarReference &varRef, const std::string &index) { - return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, - varRef.getVar().access.getDims(), index); + return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, 1, + varRef.getVar().access.getDims(), index); }); } //-------------------------------------------------------------------------- @@ -2028,7 +2028,7 @@ void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateW env, cg, idxName, [&cg](const Models::WUVarReference &varRef, const std::string &index) { - return cg.getVarRefIndex(varRef.getVar().access.getDims(), index); + return cg.getVarRefIndex(1, varRef.getVar().access.getDims(), index); }); } } // namespace GeNN::CodeGenerator::SingleThreadedCPU diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 79eb63f4f4..8cd6598d62 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -373,10 +373,11 @@ void buildStandardSynapseEnvironment(const BackendBase &backend, EnvironmentGrou } //-------------------------------------------------------------------------- template -void buildStandardCustomUpdateEnvironment(const BackendBase &backend, EnvironmentGroupMergedField &env) +void buildStandardCustomUpdateEnvironment(const BackendBase &backend, EnvironmentGroupMergedField &env, unsigned int batchSize) { // If batching is enabled, calculate batch offset - if(env.getGroup().getArchetype().isBatched()) { + const bool batched = 
(env.getGroup().getArchetype().getDims() & VarAccessDim::BATCH) && (batchSize > 1); + if(batched) { env.add(Type::Uint32.addConst(), "_batch_offset", "batchOffset", {env.addInitialiser("const unsigned int batchOffset = $(size) * $(batch);")}); } @@ -397,7 +398,7 @@ void buildStandardCustomUpdateEnvironment(const BackendBase &backend, Environmen {env.addInitialiser("const unsigned int delayOffset = $(_delay_slot) * $(size);")}); // If batching is also enabled, calculate offset including delay and batch - if(env.getGroup().getArchetype().isBatched()) { + if(batched) { const std::string numDelaySlotsStr = std::to_string(env.getGroup().getArchetype().getDelayNeuronGroup()->getNumDelaySlots()); env.add(Type::Uint32.addConst(), "_batch_delay_slot", "batchDelaySlot", {env.addInitialiser("const unsigned int batchDelaySlot = ($(batch) * " + numDelaySlotsStr + ") + $(_delay_slot);")}); @@ -410,10 +411,10 @@ void buildStandardCustomUpdateEnvironment(const BackendBase &backend, Environmen } //-------------------------------------------------------------------------- template -void buildStandardCustomUpdateWUEnvironment(const BackendBase &backend, EnvironmentGroupMergedField &env) +void buildStandardCustomUpdateWUEnvironment(const BackendBase &backend, EnvironmentGroupMergedField &env, unsigned int batchSize) { // Add batch offset if group is batched - if(env.getGroup().getArchetype().isBatched()) { + if((env.getGroup().getArchetype().getDims() & VarAccessDim::BATCH) && (batchSize > 1)) { env.add(Type::Uint32.addConst(), "_batch_offset", "batchOffset", {env.addInitialiser("const unsigned int batchOffset = $(_size) * $(batch);")}); } @@ -571,19 +572,19 @@ void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const { - buildStandardCustomUpdateEnvironment(*this, env); + buildStandardCustomUpdateEnvironment(*this, env, batchSize); } //----------------------------------------------------------------------- -void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const { - buildStandardCustomUpdateWUEnvironment(*this, env); + buildStandardCustomUpdateWUEnvironment(*this, env, batchSize); } //----------------------------------------------------------------------- -void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const { - buildStandardCustomUpdateWUEnvironment(*this, env); + buildStandardCustomUpdateWUEnvironment(*this, env, batchSize); } //----------------------------------------------------------------------- void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const @@ -601,22 +602,22 @@ void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const { buildCustomUpdateSizeEnvironment(env); - buildStandardCustomUpdateEnvironment(*this, env); + buildStandardCustomUpdateEnvironment(*this, env, batchSize); } //----------------------------------------------------------------------- -void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const { 
buildCustomUpdateWUSizeEnvironment(*this, env); - buildStandardCustomUpdateWUEnvironment(*this, env); + buildStandardCustomUpdateWUEnvironment(*this, env, batchSize); } //----------------------------------------------------------------------- -void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const { buildCustomUpdateWUSizeEnvironment(*this, env); - buildStandardCustomUpdateWUEnvironment(*this, env); + buildStandardCustomUpdateWUEnvironment(*this, env, batchSize); } //----------------------------------------------------------------------- void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const @@ -697,24 +698,26 @@ std::string BackendBase::getReductionOperation(const std::string &reduction, con } } //----------------------------------------------------------------------- -std::vector BackendBase::genInitReductionTargets(CodeStream &os, const CustomUpdateGroupMerged &cg, const std::string &idx) const +std::vector BackendBase::genInitReductionTargets(CodeStream &os, const CustomUpdateGroupMerged &cg, + unsigned int batchSize, const std::string &idx) const { return genInitReductionTargets( - os, cg, idx, - [&cg](const Models::VarReference &varRef, const std::string &index) + os, cg, batchSize, idx, + [batchSize, &cg](const Models::VarReference &varRef, const std::string &index) { - return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, - varRef.getVar().access.getDims(), index); + return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, batchSize, + varRef.getVar().access.getDims(), index); }); } //----------------------------------------------------------------------- -std::vector BackendBase::genInitReductionTargets(CodeStream &os, const CustomUpdateWUGroupMerged &cg, const std::string &idx) const +std::vector BackendBase::genInitReductionTargets(CodeStream &os, const CustomUpdateWUGroupMerged &cg, + unsigned int batchSize, const std::string &idx) const { return genInitReductionTargets( - os, cg, idx, - [&cg](const Models::WUVarReference &varRef, const std::string &index) + os, cg, batchSize, idx, + [batchSize, &cg](const Models::WUVarReference &varRef, const std::string &index) { - return cg.getVarRefIndex(varRef.getVar().access.getDims(), index); + return cg.getVarRefIndex(batchSize, varRef.getVar().access.getDims(), index); }); } } // namespace GeNN::CodeGenerator \ No newline at end of file diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index fc52106e0b..fb73b6309d 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -205,12 +205,12 @@ size_t BackendSIMT::getNumInitialisationRNGStreams(const ModelSpecMerged &modelM //-------------------------------------------------------------------------- size_t BackendSIMT::getPaddedNumCustomUpdateThreads(const CustomUpdateInternal &cg, unsigned int batchSize) const { - const size_t numCopies = (cg.isBatched() && !cg.isBatchReduction()) ? batchSize : 1; + const size_t numCopies = ((cg.getDims() & VarAccessDim::BATCH) && !cg.isBatchReduction()) ? 
batchSize : 1; if (cg.isNeuronReduction()) { return padKernelSize(32 * numCopies, KernelCustomUpdate); } - else if (cg.isPerNeuron()) { + else if (!(cg.getDims() & VarAccessDim::NEURON)) { return numCopies * padKernelSize(cg.getSize(), KernelCustomUpdate); } else { @@ -221,7 +221,7 @@ size_t BackendSIMT::getPaddedNumCustomUpdateThreads(const CustomUpdateInternal & size_t BackendSIMT::getPaddedNumCustomUpdateWUThreads(const CustomUpdateWUInternal &cg, unsigned int batchSize) const { const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); - const size_t numCopies = (cg.isBatched() && !cg.isBatchReduction()) ? batchSize : 1; + const size_t numCopies = ((cg.getDims() & VarAccessDim::BATCH) && !cg.isBatchReduction()) ? batchSize : 1; if(sgInternal->getMatrixType() & SynapseMatrixWeight::KERNEL) { return numCopies * padKernelSize(sgInternal->getKernelSizeFlattened(), KernelCustomUpdate); @@ -239,7 +239,7 @@ size_t BackendSIMT::getPaddedNumCustomUpdateTransposeWUThreads(const CustomUpdat const size_t paddedNumPre = padKernelSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelCustomTransposeUpdate); const size_t paddedNumPost = padKernelSize(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons(), KernelCustomTransposeUpdate); - const size_t numCopies = cg.isBatched() ? batchSize : 1; + const size_t numCopies = (cg.getDims() & VarAccessDim::BATCH) ? batchSize : 1; return numCopies * paddedNumPre * paddedNumPost / getKernelBlockSize(KernelCustomTransposeUpdate); } //-------------------------------------------------------------------------- @@ -950,7 +950,8 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge CodeStream::Scope b(groupEnv.getStream()); // Initialise reduction targets - const auto reductionTargets = genInitReductionTargets(groupEnv.getStream(), cg, groupEnv["id"]); + const auto reductionTargets = genInitReductionTargets(groupEnv.getStream(), cg, + batchSize, groupEnv["id"]); // Loop through batches // **TODO** this naive approach is good for reduction when there are lots of neurons/synapses but, @@ -960,11 +961,11 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge CodeStream::Scope b(groupEnv.getStream()); EnvironmentGroupMergedField batchEnv(groupEnv, cg); batchEnv.add(Type::Uint32.addConst(), "batch", "batch"); - buildStandardEnvironment(batchEnv); + buildStandardEnvironment(batchEnv, batchSize); // **THINK** it would be great to 'lift' reads of SHARED variables out of this loop cg.generateCustomUpdate( - *this, batchEnv, + *this, batchEnv, batchSize, [&reductionTargets, this](auto &env, const auto&) { // Loop through reduction targets and generate reduction @@ -995,10 +996,10 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge {groupEnv.addInitialiser("const unsigned int batch = $(id) / 32;")}); EnvironmentGroupMergedField batchEnv(groupEnv, cg); - buildStandardEnvironment(batchEnv); + buildStandardEnvironment(batchEnv, batchSize); // Initialise reduction targets - const auto reductionTargets = genInitReductionTargets(batchEnv.getStream(), cg); + const auto reductionTargets = genInitReductionTargets(batchEnv.getStream(), cg, batchSize); // Loop through warps of data // **TODO** this approach is good for reductions where there are small numbers of neurons but large batches sizes but, @@ -1012,7 +1013,7 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge // **THINK** it would be great to 'lift' reads of NEURON_SHARED 
variables out of this loop cg.generateCustomUpdate( - *this, batchEnv, + *this, batchEnv, batchSize, [&reductionTargets, this](auto &env, const auto&) { // Loop through reduction targets and generate reduction @@ -1042,24 +1043,25 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge } } // Otherwise, if this update isn't per-neuron - else if (!cg.getArchetype().isPerNeuron()) { + else if (cg.getArchetype().getDims() & VarAccessDim::NEURON) { // Use local ID for batch and always use zero for ID groupEnv.add(Type::Uint32.addConst(), "batch", "$(_id)"); groupEnv.add(Type::Uint32.addConst(), "id", "0"); groupEnv.getStream() << "// only do this for existing neurons" << std::endl; - groupEnv.getStream() << "if(" << groupEnv["batch"] << " < " << (cg.getArchetype().isBatched() ? batchSize : 1) << ")"; + groupEnv.getStream() << "if(" << groupEnv["batch"] << " < " << ((cg.getArchetype().getDims() & VarAccessDim::BATCH) ? batchSize : 1) << ")"; { CodeStream::Scope b(groupEnv.getStream()); EnvironmentGroupMergedField batchEnv(groupEnv, cg); - buildStandardEnvironment(batchEnv); + buildStandardEnvironment(batchEnv, batchSize); - cg.generateCustomUpdate(*this, batchEnv, [](auto&, auto&){}); + cg.generateCustomUpdate(*this, batchEnv, batchSize, + [](auto&, auto&){}); } } // Otherwise else { - if(cg.getArchetype().isBatched()) { + if((cg.getArchetype().getDims() & VarAccessDim::BATCH) && (batchSize > 1)) { // Split ID into intra-batch ID and batch // **TODO** fast-divide style optimisations here const std::string blockSizeStr = std::to_string(blockSize); @@ -1077,13 +1079,14 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge } EnvironmentGroupMergedField batchEnv(groupEnv, cg); - buildStandardEnvironment(batchEnv); + buildStandardEnvironment(batchEnv, batchSize); batchEnv.getStream() << "// only do this for existing neurons" << std::endl; batchEnv.print("if($(id) < $(size))"); { CodeStream::Scope b(batchEnv.getStream()); - cg.generateCustomUpdate(*this, batchEnv, [](auto&, auto&){}); + cg.generateCustomUpdate(*this, batchEnv, batchSize, + [](auto&, auto&){}); } } }); @@ -1112,7 +1115,7 @@ void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, ModelSpecMer // If update isn't a batch reduction if(!cg.getArchetype().isBatchReduction()) { // If it's batched - if(cg.getArchetype().isBatched()) { + if((cg.getArchetype().getDims() & VarAccessDim::BATCH) && (batchSize > 1)) { // Split ID into intra-batch ID and batch // **TODO** fast-divide style optimisations here const std::string blockSizeStr = std::to_string(blockSize); @@ -1163,7 +1166,8 @@ void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, ModelSpecMer synEnv.add(Type::Uint32.addConst(), "id_syn", "$(id)"); // Initialise reduction targets - const auto reductionTargets = genInitReductionTargets(synEnv.getStream(), cg, synEnv["id_syn"]); + const auto reductionTargets = genInitReductionTargets(synEnv.getStream(), cg, + batchSize, synEnv["id_syn"]); // If this is a reduction if(cg.getArchetype().isBatchReduction()) { @@ -1178,10 +1182,10 @@ void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, ModelSpecMer // **NOTE** use scope to force batchEnv to generate all code within loop { EnvironmentGroupMergedField batchEnv(synEnv, cg); - buildStandardEnvironment(batchEnv); + buildStandardEnvironment(batchEnv, batchSize); cg.generateCustomUpdate( - *this, batchEnv, + *this, batchEnv, batchSize, [&reductionTargets, this](auto &env, auto &cg) { // If this is a 
reduction @@ -1216,12 +1220,13 @@ void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, Mod BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroup, size_t &idStart) const { // Generate 2D array + const unsigned int batchSize = modelMerged.getModel().getBatchSize(); const size_t blockSize = getKernelBlockSize(KernelCustomTransposeUpdate); env.getStream() << getSharedPrefix() << " float shTile[" << blockSize << "][" << (blockSize + 1) << "];" << std::endl; genParallelGroup( env, modelMerged, memorySpaces, updateGroup, idStart, &ModelSpecMerged::genMergedCustomUpdateTransposeWUGroups, - [&modelMerged, this](const CustomUpdateWUInternal &cu) { return getPaddedNumCustomUpdateTransposeWUThreads(cu, modelMerged.getModel().getBatchSize()); }, - [blockSize, this](EnvironmentExternalBase &env, CustomUpdateTransposeWUGroupMerged &cg) + [batchSize, &modelMerged, this](const CustomUpdateWUInternal &cu) { return getPaddedNumCustomUpdateTransposeWUThreads(cu, batchSize); }, + [batchSize, blockSize, this](EnvironmentExternalBase &env, CustomUpdateTransposeWUGroupMerged &cg) { EnvironmentGroupMergedField groupEnv(env, cg); buildSizeEnvironment(groupEnv); @@ -1232,7 +1237,7 @@ void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, Mod // Calculate what block this kernel starts at (because of kernel merging, it may not start at block 0) groupEnv.getStream() << "const unsigned int blockStart = " << groupEnv["_group_start_id"] << " / " << blockSize << ";" << std::endl; - if(cg.getArchetype().isBatched()) { + if((cg.getArchetype().getDims() & VarAccessDim::BATCH) && (batchSize > 1)) { // If there's multiple batches we also need to know how many Y blocks and hence total blocks there are groupEnv.getStream() << "const unsigned int numYBlocks = (" << groupEnv["num_pre"] << " + " << (blockSize - 1) << ") / " << blockSize << ";" << std::endl; groupEnv.getStream() << "const unsigned int numBlocks = numXBlocks * numYBlocks;" << std::endl; @@ -1252,7 +1257,7 @@ void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, Mod } EnvironmentGroupMergedField batchEnv(groupEnv, cg); - buildStandardEnvironment(batchEnv); + buildStandardEnvironment(batchEnv, batchSize); // Add field for transpose field and get its name const std::string transposeVarName = cg.addTransposeField(*this, batchEnv); @@ -1287,7 +1292,7 @@ void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, Mod synEnv.add(Type::Uint32.addConst(), "id_syn", "idx", {synEnv.addInitialiser("const unsigned int idx = ((y + j) * $(num_post)) + x;")}); cg.generateCustomUpdate( - *this, synEnv, + *this, synEnv, batchSize, [&transposeVarName, this](auto &env, const auto&) { // Write forward weight to shared memory @@ -1317,7 +1322,7 @@ void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, Mod { CodeStream::Scope b(batchEnv.getStream()); batchEnv.print("$(" + transposeVarName + "_transpose)["); - if(cg.getArchetype().isBatched()) { + if((cg.getArchetype().getDims() & VarAccessDim::BATCH) && (batchSize > 1)) { batchEnv.print("$(_batch_offset) + "); } batchEnv.printLine("((y + j) * $(num_pre)) + x] = shTile[" + getThreadID(0) + "][" + getThreadID(1) + " + j];"); @@ -1444,7 +1449,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer [batchSize, this](EnvironmentExternalBase &env, CustomUpdateInitGroupMerged &cg) { EnvironmentGroupMergedField groupEnv(env, cg); - buildStandardEnvironment(groupEnv); + 
buildStandardEnvironment(groupEnv, batchSize); groupEnv.getStream() << "// only do this for existing variables" << std::endl; groupEnv.print("if($(id) < $(size))"); @@ -1471,7 +1476,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer [batchSize, this](EnvironmentExternalBase &env, CustomWUUpdateInitGroupMerged &cg) { EnvironmentGroupMergedField groupEnv(env, cg); - buildStandardEnvironment(groupEnv); + buildStandardEnvironment(groupEnv, batchSize); const SynapseGroup *sg = cg.getArchetype().getSynapseGroup(); genSynapseVarInit(groupEnv, batchSize, cg, cg.getArchetype().isInitRNGRequired(), (sg->getMatrixType() & SynapseMatrixWeight::KERNEL), sg->getKernelSize().size()); @@ -1749,7 +1754,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS [batchSize, numInitializeThreads, this](EnvironmentExternalBase &env, CustomWUUpdateSparseInitGroupMerged &cg) { EnvironmentGroupMergedField groupEnv(env, cg); - buildStandardEnvironment(groupEnv); + buildStandardEnvironment(groupEnv, batchSize); // If this custom update requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index a46953d5d5..3544ff2b7e 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -170,7 +170,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back for(const auto &v : getArchetype().getCustomConnectivityUpdateModel()->getPreVarRefs()) { // If model isn't batched or variable isn't duplicated const auto &varRef = getArchetype().getPreVarReferences().at(v.name); - if(batchSize == 1 || !varRef.isDuplicated()) { + if(batchSize == 1 || !(varRef.getDims() & VarAccessDim::BATCH)) { // Determine index const std::string index = (varRef.getDelayNeuronGroup() != nullptr) ? "$(_pre_delay_offset) + $(id_pre)" : "$(id_pre)"; @@ -240,8 +240,8 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Use subsequent parameters to initialise new synapse's variables referenced via the custom connectivity update for (size_t i = 0; i < ccuVarRefs.size(); i++) { // If model is batched and this variable is duplicated - if (batchSize > 1 && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) - { + const auto &varRef = getArchetype().getVarReferences().at(ccuVarRefs[i].name); + if (batchSize > 1 && (varRef.getDims() & VarAccessDim::BATCH)) { // Copy parameter into a register (just incase it's e.g. 
a RNG call) and copy into all batches addSynapse << "const " << ccuVarRefs[i].type.resolve(getTypeContext()).getName() << " _" << ccuVarRefs[i].name << "Val = $(" << (1 + ccuVars.size() + i) << ");" << std::endl; addSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; @@ -261,8 +261,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Loop through any other dependent variables for (size_t i = 0; i < dependentVars.size(); i++) { // If model is batched and this dependent variable is duplicated - if (batchSize > 1 && dependentVars.at(i).isDuplicated()) - { + if (batchSize > 1 && (dependentVars.at(i).getDims() & VarAccessDim::BATCH)) { // Loop through all batches and zero addSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; { @@ -306,8 +305,8 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Loop through variable references for (size_t i = 0; i < ccuVarRefs.size(); i++) { // If model is batched and this variable is duplicated - if (batchSize > 1 && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) - { + const auto &varRef = getArchetype().getVarReferences().at(ccuVarRefs[i].name); + if (batchSize > 1 && (varRef.getDims() & VarAccessDim::BATCH)) { // Loop through all batches and copy custom connectivity update variable references from end of row over synapse to be deleted removeSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; { @@ -325,7 +324,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Loop through any other dependent variables for (size_t i = 0; i < dependentVars.size(); i++) { // If model is batched and this dependent variable is duplicated - if (batchSize > 1 && dependentVars.at(i).isDuplicated()) { + if (batchSize > 1 && (dependentVars.at(i).getDims() & VarAccessDim::BATCH)) { // Loop through all batches and copy dependent variable from end of row over synapse to be deleted removeSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; { diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index a14b1ffee0..ae9022c9a1 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -44,7 +44,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateGroupMerged::getHashDigest() return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, EnvironmentExternalBase &env, +void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize, BackendBase::GroupHandlerEnv genPostamble) { // Add parameters, derived parameters and EGPs to environment @@ -60,17 +60,17 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", - [this, &cuEnv](const std::string&, VarAccess d) + [this, batchSize, &cuEnv](const std::string&, VarAccess d) { - return getVarIndex(d.getDims(), "$(id)"); + return getVarIndex(batchSize, clearDim(getArchetype().getDims(), d.getDims()), "$(id)"); }); // Create an environment which caches variable references in local variables if they are accessed 
EnvironmentLocalVarRefCache varRefEnv( *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", - [this, &varEnv](const std::string&, const Models::VarReference &v) + [this, batchSize, &varEnv](const std::string&, const Models::VarReference &v) { - return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, + return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, batchSize, v.getVar().access.getDims(), "$(id)"); }); @@ -81,10 +81,10 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E genPostamble(varRefEnv, *this); } //---------------------------------------------------------------------------- -std::string CustomUpdateGroupMerged::getVarIndex(VarAccessDim varDims, const std::string &index) const +std::string CustomUpdateGroupMerged::getVarIndex(unsigned int batchSize, VarAccessDim varDims, const std::string &index) const { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? - const bool batched = (varDims & VarAccessDim::BATCH) && getArchetype().isBatched(); + const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); if (!(varDims & VarAccessDim::NEURON)) { return batched ? "$(batch)" : "0"; } @@ -98,11 +98,11 @@ std::string CustomUpdateGroupMerged::getVarIndex(VarAccessDim varDims, const std } } //---------------------------------------------------------------------------- -std::string CustomUpdateGroupMerged::getVarRefIndex(bool delay, VarAccessDim varDims, const std::string &index) const +std::string CustomUpdateGroupMerged::getVarRefIndex(bool delay, unsigned int batchSize, VarAccessDim varDims, const std::string &index) const { // If delayed, variable is shared, the batch size is one or this custom update isn't batched, batch delay offset isn't required if(delay) { - const bool batched = (varDims & VarAccessDim::BATCH) && getArchetype().isBatched(); + const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); if (!(varDims & VarAccessDim::NEURON)) { return batched ? 
"$(_batch_delay_slot)" : "$(_delay_slot)"; } @@ -117,7 +117,7 @@ std::string CustomUpdateGroupMerged::getVarRefIndex(bool delay, VarAccessDim var } } else { - return getVarIndex(varDims, index); + return getVarIndex(batchSize, varDims, index); } } //---------------------------------------------------------------------------- @@ -171,7 +171,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWUGroupMergedBase::getHashDi return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &backend, EnvironmentExternalBase &env, +void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize, BackendBase::GroupHandlerEnv genPostamble) { // Add parameters, derived parameters and EGPs to environment @@ -187,17 +187,17 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", - [this, &cuEnv](const std::string&, VarAccess d) + [this, batchSize, &cuEnv](const std::string&, VarAccess d) { - return getVarIndex(d.getDims(), "$(id_syn)"); + return getVarIndex(batchSize, clearDim(getArchetype().getDims(), d.getDims()), "$(id_syn)"); }); // Create an environment which caches variable references in local variables if they are accessed EnvironmentLocalVarRefCache varRefEnv( *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", - [this, &varEnv](const std::string&, const Models::WUVarReference &v) + [this, batchSize, &varEnv](const std::string&, const Models::WUVarReference &v) { - return getVarRefIndex(v.getVar().access.getDims(), "$(id_syn)"); + return getVarRefIndex(batchSize, v.getVar().access.getDims(), "$(id_syn)"); }); Transpiler::ErrorHandler errorHandler("Custom update '" + getArchetype().getName() + "' update code"); @@ -207,16 +207,17 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back genPostamble(varRefEnv, *this); } //---------------------------------------------------------------------------- -std::string CustomUpdateWUGroupMergedBase::getVarIndex(VarAccessDim varDims, const std::string &index) const +std::string CustomUpdateWUGroupMergedBase::getVarIndex(unsigned int batchSize, VarAccessDim varDims, const std::string &index) const { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? - return (((varDims & VarAccessDim::BATCH) && getArchetype().isBatched()) ? "$(_batch_offset) + " : "") + index; + return (((varDims & VarAccessDim::BATCH) && batchSize > 1) ? "$(_batch_offset) + " : "") + index; } //---------------------------------------------------------------------------- -std::string CustomUpdateWUGroupMergedBase::getVarRefIndex(VarAccessDim varDims, const std::string &index) const +std::string CustomUpdateWUGroupMergedBase::getVarRefIndex(unsigned int batchSize, VarAccessDim varDims, const std::string &index) const { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? - return (((varDims & VarAccessDim::BATCH) && getArchetype().isBatched()) ? "$(_batch_offset) + " : "") + index; + + return (((varDims & VarAccessDim::BATCH) && batchSize > 1) ? 
"$(_batch_offset) + " : "") + index; } // ---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 9c510a7063..430868c94f 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -1177,7 +1177,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [batchSize](const CustomUpdateInternal &c, const Models::Base::Var &var) { return getVarSize(var.access.getDims(), - c.getSize(), batchSize, 1, c.isBatched()); + c.getSize(), batchSize, 1, c.getDims() & VarAccessDim::BATCH); }); genCustomUpdate( @@ -1192,7 +1192,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, ? sg->getKernelSizeFlattened() : sg->getSrcNeuronGroup()->getNumNeurons() * backend.getSynapticMatrixRowStride(*sg)); return getVarSize(var.access.getDims(), count, - batchSize, 1, c.isBatched()); + batchSize, 1, c.getDims() & VarAccessDim::BATCH); }); allVarStreams << std::endl; diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 2a98bb218d..f5fbc0280b 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -853,7 +853,7 @@ void CustomUpdateInitGroupMerged::generateInit(const BackendBase &backend, Envir { // Initialise custom update variables genInitNeuronVarCode(backend, env, *this, "", "size", 1, - getArchetype().isBatched() ? batchSize : 1); + (getArchetype().getDims() & VarAccessDim::BATCH) ? batchSize : 1); } // ---------------------------------------------------------------------------- @@ -911,7 +911,7 @@ void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, Env // Loop through rows const std::string stride = kernel ? "$(_kernel_size)" : "$(num_pre) * $(_row_stride)"; genInitWUVarCode( - backend, groupEnv, *this, stride, getArchetype().isBatched() ? batchSize : 1, false, + backend, groupEnv, *this, stride, (getArchetype().getDims() & VarAccessDim::BATCH) ? batchSize : 1, false, [&backend, kernel, this](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { if (kernel) { @@ -966,7 +966,7 @@ void CustomWUUpdateSparseInitGroupMerged::generateInit(const BackendBase &backen { genInitWUVarCode( backend, env, *this, "$(num_pre) * $(_row_stride)", - getArchetype().isBatched() ? batchSize : 1, false, + (getArchetype().getDims() & VarAccessDim::BATCH) ? 
batchSize : 1, false, [&backend](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { return backend.genSparseSynapseVariableRowInit(varInitEnv, handler); diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index 8e35a1ccad..2ad273055a 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -313,7 +313,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdate::getHashDigest( { boost::uuids::detail::sha1 hash; Type::updateHash(v.getVar().type, hash); - Utils::updateHash(v.isDuplicated(), hash); + Utils::updateHash(v.getDims(), hash); return hash.get_digest(); }); @@ -329,7 +329,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdate::getHashDigest( // Update hash with duplication mode of synaptic variable references for(const auto &v : getVarReferences()) { - Utils::updateHash(v.second.isDuplicated(), hash); + Utils::updateHash(v.second.getDims(), hash); } return hash.get_digest(); diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index 9da7b2a5be..a5130ddc80 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -38,7 +38,7 @@ CustomUpdateBase::CustomUpdateBase(const std::string &name, const std::string &u : m_Name(name), m_UpdateGroupName(updateGroupName), m_CustomUpdateModel(customUpdateModel), m_Params(params), m_VarInitialisers(varInitialisers), m_EGPReferences(egpReferences), m_VarLocation(varInitialisers.size(), defaultVarLocation), m_ExtraGlobalParamLocation(customUpdateModel->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation), - m_Batched(false) + m_Dims{0} { // Validate names Utils::validatePopName(name, "Custom update"); @@ -91,13 +91,13 @@ void CustomUpdateBase::updateHash(boost::uuids::detail::sha1 &hash) const { Utils::updateHash(getCustomUpdateModel()->getHashDigest(), hash); Utils::updateHash(getUpdateGroupName(), hash); - Utils::updateHash(isBatched(), hash); + Utils::updateHash(getDims(), hash); } //---------------------------------------------------------------------------- void CustomUpdateBase::updateInitHash(boost::uuids::detail::sha1 &hash) const { Utils::updateHash(getCustomUpdateModel()->getVars(), hash); - Utils::updateHash(isBatched(), hash); + Utils::updateHash(getDims(), hash); // Include variable initialiser hashes for(const auto &w : getVarInitialisers()) { @@ -122,7 +122,7 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, const std::unordered_map &egpReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) : CustomUpdateBase(name, updateGroupName, customUpdateModel, params, varInitialisers, egpReferences, defaultVarLocation, defaultExtraGlobalParamLocation), - m_VarReferences(varReferences), m_Size(varReferences.empty() ? 0 : varReferences.begin()->second.getSize()), m_DelayNeuronGroup(nullptr), m_PerNeuron(false) + m_VarReferences(varReferences), m_Size(varReferences.empty() ? 
0 : varReferences.begin()->second.getSize()), m_DelayNeuronGroup(nullptr) { // Validate parameters, variables and variable references getCustomUpdateModel()->validate(getParams(), getVarInitialisers(), getVarReferences(), "Custom update " + getName()); @@ -135,7 +135,7 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro Models::checkVarReferenceTypes(m_VarReferences, getCustomUpdateModel()->getVarRefs()); // Update is per-neuron if any variables or variable reference targets have neuron dimension - const auto modelVars = getCustomUpdateModel()->getVars(); + /*const auto modelVars = getCustomUpdateModel()->getVars(); m_PerNeuron = std::any_of(m_VarReferences.cbegin(), m_VarReferences.cend(), [](const auto& v) { @@ -159,7 +159,7 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro throw std::runtime_error("Variable references to SHARED_NEURON variables in per-neuron custom updates cannot be read-write."); } } - + */ // Check only one type of reduction is specified if (isBatchReduction() && isNeuronReduction()) { throw std::runtime_error("Custom updates cannot perform batch and neuron reductions simultaneously."); @@ -179,7 +179,7 @@ void CustomUpdate::finalise(double dt, unsigned int batchSize) CustomUpdateBase::finalise(dt); // Check variable reference batching - checkVarReferenceBatching(m_VarReferences, batchSize); + checkVarReferenceDims(m_VarReferences, batchSize); // If any variable references have delays auto delayRef = std::find_if(m_VarReferences.cbegin(), m_VarReferences.cend(), @@ -203,8 +203,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdate::getHashDigest() const boost::uuids::detail::sha1 hash; CustomUpdateBase::updateHash(hash); - // Update hash with whether custom update is per-neuron and if delay is required - Utils::updateHash(isPerNeuron(), hash); + // Update hash with whether delay is required const bool delayed = (getDelayNeuronGroup() != nullptr); Utils::updateHash(delayed, hash); @@ -229,7 +228,6 @@ boost::uuids::detail::sha1::digest_type CustomUpdate::getInitHashDigest() const // Superclass boost::uuids::detail::sha1 hash; CustomUpdateBase::updateInitHash(hash); - Utils::updateHash(isPerNeuron(), hash); return hash.get_digest(); } @@ -263,20 +261,6 @@ CustomUpdateWU::CustomUpdateWU(const std::string &name, const std::string &updat { throw std::runtime_error("All referenced variables must belong to the same synapse group."); } - - // Give error if custom update model includes any shared neuron variables - // **NOTE** because there's no way to reference neuron variables with WUVarReferences, - // this safely checks for attempts to do neuron reductions - /*const auto vars = getCustomUpdateModel()->getVars(); - if (std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v) - { - return (v.getAccess(VarAccess::READ_WRITE) & VarAccessDuplication::SHARED_NEURON); - })) - { - throw std::runtime_error("Custom weight updates cannot use models with SHARED_NEURON variables."); - }*/ - // If this is a transpose operation if(isTransposeOperation()) { // Check that it isn't also a reduction @@ -311,7 +295,7 @@ void CustomUpdateWU::finalise(double dt, unsigned int batchSize) CustomUpdateBase::finalise(dt); // Check variable reference types - checkVarReferenceBatching(m_VarReferences, batchSize); + checkVarReferenceDims(m_VarReferences, batchSize); } //---------------------------------------------------------------------------- bool CustomUpdateWU::isTransposeOperation() const diff --git 
a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 1029b238c1..a4a278d6e9 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -73,21 +73,16 @@ std::string VarReference::getTargetName() const m_Detail); } //---------------------------------------------------------------------------- -bool VarReference::isDuplicated() const +VarAccessDim VarReference::getDims() const { - // If target variable has BATCH dimension - if(getVar().access.getDims() & VarAccessDim::BATCH) { - return std::visit( - Utils::Overload{ - [](const CURef &ref) { return ref.group->isBatched(); }, - [](const CCUPreRef&){ return false; }, - [](const CCUPostRef&){ return false; }, - [](const auto&) { return true; }}, - m_Detail); - } - else { - return false; - } + const VarAccessDim varDims = getVar().access.getDims(); + return std::visit( + Utils::Overload{ + [varDims](const CURef &ref) { return clearDim(ref.group->getDims(), varDims); }, + [varDims](const CCUPreRef&){ return clearDim(varDims, VarAccessDim::BATCH); }, + [varDims](const CCUPostRef&){ return clearDim(varDims, VarAccessDim::BATCH); }, + [varDims](const auto&) { return varDims; }}, + m_Detail); } //---------------------------------------------------------------------------- CustomUpdate *VarReference::getReferencedCustomUpdate() const @@ -174,20 +169,15 @@ std::string WUVarReference::getTargetName() const m_Detail); } //---------------------------------------------------------------------------- -bool WUVarReference::isDuplicated() const +VarAccessDim WUVarReference::getDims() const { - // If target variable has BATCH dimension - if(getVar().access.getDims() & VarAccessDim::BATCH) { - return std::visit( - Utils::Overload{ - [](const CURef &ref) { return ref.group->isBatched(); }, - [](const CCURef&) { return false; }, - [](const WURef&) { return true; }}, - m_Detail); - } - else { - return false; - } + const VarAccessDim varDims = getVar().access.getDims(); + return std::visit( + Utils::Overload{ + [varDims](const CURef &ref) { return clearDim(ref.group->getDims(), varDims); }, + [varDims](const CCURef&) { return clearDim(varDims, VarAccessDim::BATCH); }, + [varDims](const WURef&) { return varDims; }}, + m_Detail); } //---------------------------------------------------------------------------- SynapseGroup *WUVarReference::getSynapseGroup() const From e38d861ca79e0cc0e8ea3451b494b90c0c837521 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 29 Aug 2023 10:15:43 +0100 Subject: [PATCH 14/60] fixed unit tests --- tests/unit/customConnectivityUpdate.cc | 2 +- tests/unit/customUpdate.cc | 76 +++++++++++++------------- tests/unit/neuronGroup.cc | 2 +- 3 files changed, 40 insertions(+), 40 deletions(-) diff --git a/tests/unit/customConnectivityUpdate.cc b/tests/unit/customConnectivityUpdate.cc index 4e8ba53e98..61490f413a 100644 --- a/tests/unit/customConnectivityUpdate.cc +++ b/tests/unit/customConnectivityUpdate.cc @@ -19,7 +19,7 @@ class StaticPulseDendriticDelayReverse : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulseDendriticDelayReverse); - SET_VARS({{"d", "uint8_t", VarAccess::READ_ONLY}, {"g", "scalar", VarAccess::READ_ONLY}}); + SET_VARS({{"d", "uint8_t", SynapseVarAccess::READ_ONLY}, {"g", "scalar", SynapseVarAccess::READ_ONLY}}); SET_SIM_CODE("addToPostDelay(g, d);\n"); }; diff --git a/tests/unit/customUpdate.cc b/tests/unit/customUpdate.cc index fd303618c1..ce9d6a7c77 100644 --- a/tests/unit/customUpdate.cc +++ b/tests/unit/customUpdate.cc @@ -24,8 +24,8 @@ class IzhikevichVariableShared : public 
NeuronModels::Izhikevich SET_PARAM_NAMES({}); SET_VARS({{"V","scalar"}, {"U", "scalar"}, - {"a", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}, {"b", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}, - {"c", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}, {"d", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}}); + {"a", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}, {"b", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}, + {"c", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}, {"d", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}}); }; IMPLEMENT_SNIPPET(IzhikevichVariableShared); @@ -34,10 +34,10 @@ class StaticPulseDendriticDelaySplit : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulseDendriticDelaySplit); - SET_VARS({{"gCommon", "scalar", VarAccess::READ_ONLY}, - {"g", "scalar", VarAccess::READ_ONLY_DUPLICATE}, - {"dCommon", "scalar", VarAccess::READ_ONLY}, - {"d", "scalar", VarAccess::READ_ONLY_DUPLICATE}}); + SET_VARS({{"gCommon", "scalar", SynapseVarAccess::READ_ONLY}, + {"g", "scalar", SynapseVarAccess::READ_ONLY_DUPLICATE}, + {"dCommon", "scalar", SynapseVarAccess::READ_ONLY}, + {"d", "scalar", SynapseVarAccess::READ_ONLY_DUPLICATE}}); SET_SIM_CODE("$(addToInSynDelay, $(gCommon) + $(g), $(dCommon) + $(d));\n"); }; @@ -61,7 +61,7 @@ class Sum2 : public CustomUpdateModels::Base SET_UPDATE_CODE("$(a) = $(mult) * ($(a) + $(b));\n"); - SET_VARS({{"mult", "scalar", VarAccess::READ_ONLY}}); + SET_VARS({{"mult", "scalar", CustomUpdateVarAccess::READ_ONLY}}); SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_WRITE}, {"b", "scalar", VarAccessMode::READ_ONLY}}); }; @@ -73,7 +73,7 @@ class Sum3 : public CustomUpdateModels::Base SET_UPDATE_CODE("$(sum) = $(scale) * ($(a) + $(b));\n"); - SET_VARS({{"sum", "scalar"}, {"scale", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}}); + SET_VARS({{"sum", "scalar"}, {"scale", "scalar", CustomUpdateVarAccess::READ_ONLY_SHARED_NEURON}}); SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_WRITE}, {"b", "scalar", VarAccessMode::READ_ONLY}}); }; @@ -169,8 +169,8 @@ class ReduceDouble : public CustomUpdateModels::Base "reduction1 = var1;\n" "reduction2 = var2;\n"); - SET_VARS({{"reduction1", "scalar", VarAccess::REDUCE_BATCH_SUM}, - {"reduction2", "scalar", VarAccess::REDUCE_NEURON_SUM}}); + SET_VARS({{"reduction1", "scalar", CustomUpdateVarAccess::REDUCE_BATCH_SUM}, + {"reduction2", "scalar", CustomUpdateVarAccess::REDUCE_NEURON_SUM}}); SET_VAR_REFS({{"var1", "scalar", VarAccessMode::READ_ONLY}, {"var2", "scalar", VarAccessMode::READ_ONLY}}); @@ -183,7 +183,7 @@ class ReduceSharedVar : public CustomUpdateModels::Base SET_UPDATE_CODE("reduction = var;\n"); - SET_VARS({{"reduction", "scalar", VarAccess::REDUCE_BATCH_SUM}}) + SET_VARS({{"reduction", "scalar", CustomUpdateVarAccess::REDUCE_BATCH_SUM}}) SET_VAR_REFS({{"var", "scalar", VarAccessMode::READ_ONLY}}); }; IMPLEMENT_SNIPPET(ReduceSharedVar); @@ -195,7 +195,7 @@ class ReduceNeuronSharedVar : public CustomUpdateModels::Base SET_UPDATE_CODE("reduction = var;\n"); - SET_VARS({{"reduction", "scalar", VarAccess::REDUCE_NEURON_SUM}}) + SET_VARS({{"reduction", "scalar", CustomUpdateVarAccess::REDUCE_NEURON_SUM}}) SET_VAR_REFS({{"var", "scalar", VarAccessMode::READ_ONLY}}); }; IMPLEMENT_SNIPPET(ReduceNeuronSharedVar); @@ -509,14 +509,14 @@ TEST(CustomUpdates, BatchingVars) model.finalise(); - EXPECT_TRUE(static_cast(sum1)->isBatched()); - EXPECT_TRUE(static_cast(sum1)->isPerNeuron()); - EXPECT_FALSE(static_cast(sum2)->isBatched()); - EXPECT_TRUE(static_cast(sum2)->isPerNeuron()); - 
EXPECT_TRUE(static_cast(sum3)->isBatched()); - EXPECT_TRUE(static_cast(sum3)->isPerNeuron()); - EXPECT_FALSE(static_cast(sum4)->isBatched()); - EXPECT_TRUE(static_cast(sum4)->isPerNeuron()); + EXPECT_TRUE(static_cast(sum1)->getDims() & VarAccessDim::BATCH); + EXPECT_TRUE(static_cast(sum1)->getDims() & VarAccessDim::NEURON); + EXPECT_FALSE(static_cast(sum2)->getDims() & VarAccessDim::BATCH); + EXPECT_TRUE(static_cast(sum2)->getDims() & VarAccessDim::NEURON); + EXPECT_TRUE(static_cast(sum3)->getDims() & VarAccessDim::BATCH); + EXPECT_TRUE(static_cast(sum3)->getDims() & VarAccessDim::NEURON); + EXPECT_FALSE(static_cast(sum4)->getDims() & VarAccessDim::BATCH); + EXPECT_TRUE(static_cast(sum4)->getDims() & VarAccessDim::NEURON); } //-------------------------------------------------------------------------- TEST(CustomUpdates, NeuronSharedVars) @@ -534,8 +534,8 @@ TEST(CustomUpdates, NeuronSharedVars) model.finalise(); auto *cuInternal = static_cast(cu); - EXPECT_TRUE(cuInternal->isBatched()); - EXPECT_FALSE(cuInternal->isPerNeuron()); + EXPECT_TRUE(cuInternal->getDims() & VarAccessDim::BATCH); + EXPECT_FALSE(cuInternal->getDims() & VarAccessDim::NEURON); } //-------------------------------------------------------------------------- TEST(CustomUpdates, BatchingWriteShared) @@ -617,10 +617,10 @@ TEST(CustomUpdates, ReductionTypeDuplicateNeuron) {}, {}, reduceVarReferences); model.finalise(); auto *cuInternal = static_cast(cu); - ASSERT_TRUE(cuInternal->isBatched()); + ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::BATCH); ASSERT_FALSE(cuInternal->isBatchReduction()); ASSERT_TRUE(cuInternal->isNeuronReduction()); - ASSERT_TRUE(cuInternal->isPerNeuron()); + ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::NEURON); } //-------------------------------------------------------------------------- TEST(CustomUpdates, ReductionTypeDuplicateNeuronInternal) @@ -640,10 +640,10 @@ TEST(CustomUpdates, ReductionTypeDuplicateNeuronInternal) {}, reduceVars, reduceVarReferences); model.finalise(); auto *cuInternal = static_cast(cu); - ASSERT_TRUE(cuInternal->isBatched()); + ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::BATCH); ASSERT_FALSE(cuInternal->isBatchReduction()); ASSERT_TRUE(cuInternal->isNeuronReduction()); - ASSERT_TRUE(cuInternal->isPerNeuron()); + ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::NEURON); } //-------------------------------------------------------------------------- TEST(CustomUpdates, ReductionTypeSharedNeuronInternal) @@ -663,10 +663,10 @@ TEST(CustomUpdates, ReductionTypeSharedNeuronInternal) {}, reduceVars, reduceVarReferences); model.finalise(); auto *cuInternal = static_cast(cu); - ASSERT_FALSE(cuInternal->isBatched()); + ASSERT_FALSE(cuInternal->getDims() & VarAccessDim::BATCH); ASSERT_FALSE(cuInternal->isBatchReduction()); ASSERT_TRUE(cuInternal->isNeuronReduction()); - ASSERT_TRUE(cuInternal->isPerNeuron()); + ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::NEURON); } //-------------------------------------------------------------------------- TEST(CustomUpdates, ReductionTypeDuplicateBatch) @@ -685,10 +685,10 @@ TEST(CustomUpdates, ReductionTypeDuplicateBatch) {}, {}, reduceVarReferences); model.finalise(); auto *cuInternal = static_cast(cu); - ASSERT_TRUE(cuInternal->isBatched()); + ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::BATCH); ASSERT_TRUE(cuInternal->isBatchReduction()); ASSERT_FALSE(cuInternal->isNeuronReduction()); - ASSERT_TRUE(cuInternal->isPerNeuron()); + ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::NEURON); } 
//-------------------------------------------------------------------------- TEST(CustomUpdates, ReductionTypeDuplicateBatchInternal) @@ -708,10 +708,10 @@ TEST(CustomUpdates, ReductionTypeDuplicateBatchInternal) {}, reduceVars, reduceVarReferences); model.finalise(); auto *cuInternal = static_cast(cu); - ASSERT_TRUE(cuInternal->isBatched()); + ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::BATCH); ASSERT_TRUE(cuInternal->isBatchReduction()); ASSERT_FALSE(cuInternal->isNeuronReduction()); - ASSERT_TRUE(cuInternal->isPerNeuron()); + ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::NEURON); } //-------------------------------------------------------------------------- TEST(CustomUpdates, NeuronSharedCustomUpdateWU) @@ -949,9 +949,9 @@ TEST(CustomUpdates, CompareDifferentBatched) CustomUpdateInternal *sum1Internal = static_cast(sum1); CustomUpdateInternal *sum2Internal = static_cast(sum2); CustomUpdateInternal *sum3Internal = static_cast(sum3); - ASSERT_TRUE(sum1Internal->isBatched()); - ASSERT_FALSE(sum2Internal->isBatched()); - ASSERT_TRUE(sum3Internal->isBatched()); + ASSERT_TRUE(sum1Internal->getDims() & VarAccessDim::BATCH); + ASSERT_FALSE(sum2Internal->getDims() & VarAccessDim::BATCH); + ASSERT_TRUE(sum3Internal->getDims() & VarAccessDim::BATCH); // Check that neither initialisation nor update of batched and unbatched can be merged ASSERT_NE(sum1Internal->getHashDigest(), sum2Internal->getHashDigest()); @@ -1099,9 +1099,9 @@ TEST(CustomUpdates, CompareDifferentWUBatched) CustomUpdateWUInternal *sum1Internal = static_cast(sum1); CustomUpdateWUInternal *sum2Internal = static_cast(sum2); CustomUpdateWUInternal *sum3Internal = static_cast(sum3); - ASSERT_TRUE(sum1Internal->isBatched()); - ASSERT_FALSE(sum2Internal->isBatched()); - ASSERT_TRUE(sum3Internal->isBatched()); + ASSERT_TRUE(sum1Internal->getDims() & VarAccessDim::BATCH); + ASSERT_FALSE(sum2Internal->getDims() & VarAccessDim::BATCH); + ASSERT_TRUE(sum3Internal->getDims() & VarAccessDim::BATCH); // Check that neither initialisation nor update of batched and unbatched can be merged ASSERT_NE(sum1Internal->getHashDigest(), sum2Internal->getHashDigest()); diff --git a/tests/unit/neuronGroup.cc b/tests/unit/neuronGroup.cc index dc9f4c2c9c..7d621dbb2d 100644 --- a/tests/unit/neuronGroup.cc +++ b/tests/unit/neuronGroup.cc @@ -19,7 +19,7 @@ class StaticPulseBack : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulseBack); - SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}}); + SET_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}}); SET_SIM_CODE( "$(addToInSyn, $(g));\n" From 492835f234d74e34dafb2c3045cebcc1ac6e85a9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 29 Aug 2023 10:31:37 +0100 Subject: [PATCH 15/60] removed GENN_EXPORT --- include/genn/genn/varAccess.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h index 649b9dc2df..47715f8559 100644 --- a/include/genn/genn/varAccess.h +++ b/include/genn/genn/varAccess.h @@ -120,7 +120,7 @@ inline VarAccessDim clearDim(VarAccessDim a, VarAccessDim b) // VarAccess //---------------------------------------------------------------------------- //! 
Wrapper class encapsulating -GENN_EXPORT class VarAccess +class VarAccess { public: VarAccess() From 524e014b367ef60d308e153f86ab51bc035b5246 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 29 Aug 2023 10:46:01 +0100 Subject: [PATCH 16/60] updated PyBind11 wrapper --- pygenn/src/genn.cc | 79 +++++++++++++++++++++++++++++++--------------- 1 file changed, 54 insertions(+), 25 deletions(-) diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc index 8be57ca6d8..c385c9d6a8 100644 --- a/pygenn/src/genn.cc +++ b/pygenn/src/genn.cc @@ -285,31 +285,45 @@ PYBIND11_MODULE(genn, m) .def("__and__", [](VarAccessMode a, VarAccessModeAttribute b){ return a & b; }, pybind11::is_operator()); - - //! Flags defining how variables should be duplicated across multiple batches - pybind11::enum_(m, "VarAccessDuplication") - .value("DUPLICATE", VarAccessDuplication::DUPLICATE) - .value("SHARED", VarAccessDuplication::SHARED) - .value("SHARED_NEURON", VarAccessDuplication::SHARED_NEURON); - - //! Supported combinations of VarAccessMode and VarAccessDuplication - pybind11::enum_(m, "VarAccess") - .value("READ_WRITE", VarAccess::READ_WRITE) - .value("READ_ONLY", VarAccess::READ_ONLY) - .value("READ_ONLY_SHARED_NEURON", VarAccess::READ_ONLY_SHARED_NEURON) - .value("READ_ONLY_DUPLICATE", VarAccess::READ_ONLY_DUPLICATE) - .value("REDUCE_BATCH_SUM", VarAccess::REDUCE_BATCH_SUM) - .value("REDUCE_BATCH_MAX", VarAccess::REDUCE_BATCH_MAX) - .value("REDUCE_NEURON_SUM", VarAccess::REDUCE_NEURON_SUM) - .value("REDUCE_NEURON_MAX", VarAccess::REDUCE_NEURON_MAX) - - .def("__and__", [](VarAccess a, VarAccessModeAttribute b){ return a & b; }, - pybind11::is_operator()) - .def("__and__", [](VarAccess a, VarAccessMode b){ return a & b; }, - pybind11::is_operator()) - .def("__and__", [](VarAccess a, VarAccessDuplication b){ return a & b; }, + + //! Flags defining dimensions this variables has + pybind11::enum_(m, "VarAccessDim") + .value("NEURON", VarAccessDim::NEURON) + .value("PRE_NEURON", VarAccessDim::PRE_NEURON) + .value("POST_NEURON", VarAccessDim::POST_NEURON) + .value("BATCH", VarAccessDim::BATCH) + + .def("__and__", [](VarAccessDim a, VarAccessDim b){ return a & b; }, pybind11::is_operator()); + //! Supported combinations of access mode and dimension for neuron variables + pybind11::enum_(m, "NeuronVarAccess") + .value("READ_WRITE", NeuronVarAccess::READ_WRITE) + .value("READ_ONLY", NeuronVarAccess::READ_ONLY) + .value("READ_ONLY_DUPLICATE", NeuronVarAccess::READ_ONLY_DUPLICATE) + .value("READ_ONLY_SHARED_NEURON", NeuronVarAccess::READ_ONLY_SHARED_NEURON); + + //! Supported combinations of access mode and dimension for synapse variables + pybind11::enum_(m, "SynapseVarAccess") + .value("READ_WRITE", SynapseVarAccess::READ_WRITE) + .value("READ_ONLY", SynapseVarAccess::READ_ONLY) + .value("READ_ONLY_DUPLICATE", SynapseVarAccess::READ_ONLY_DUPLICATE); + + + //! Supported combinations of access mode and dimension for custom update variables + /*! 
The axes are defined 'subtractively' ie VarAccessDim::BATCH indicates that this axis should be removed */ + pybind11::enum_(m, "CustomUpdateVarAccess") + .value("READ_WRITE", CustomUpdateVarAccess::READ_WRITE) + .value("READ_ONLY", CustomUpdateVarAccess::READ_ONLY) + .value("READ_WRITE_SHARED", CustomUpdateVarAccess::READ_WRITE_SHARED) + .value("READ_ONLY_SHARED", CustomUpdateVarAccess::READ_ONLY_SHARED) + .value("READ_WRITE_SHARED_NEURON", CustomUpdateVarAccess::READ_WRITE_SHARED_NEURON) + .value("READ_ONLY_SHARED_NEURON", CustomUpdateVarAccess::READ_ONLY_SHARED_NEURON) + .value("REDUCE_BATCH_SUM", CustomUpdateVarAccess::REDUCE_BATCH_SUM) + .value("REDUCE_BATCH_MAX", CustomUpdateVarAccess::REDUCE_BATCH_MAX) + .value("REDUCE_NEURON_SUM", CustomUpdateVarAccess::REDUCE_NEURON_SUM) + .value("REDUCE_NEURON_MAX", CustomUpdateVarAccess::REDUCE_NEURON_MAX); + //! Locations of variables pybind11::enum_(m, "VarLocation") .value("HOST", VarLocation::HOST) @@ -498,7 +512,7 @@ PYBIND11_MODULE(genn, m) .def_property_readonly("var_references", &CustomUpdate::getVarReferences) // **NOTE** we use the 'publicist' pattern to expose some protected properties - .def_property_readonly("_is_batched", &CustomUpdateInternal::isBatched); + .def_property_readonly("_dims", &CustomUpdateInternal::getDims); //------------------------------------------------------------------------ @@ -508,7 +522,7 @@ PYBIND11_MODULE(genn, m) .def_property_readonly("var_references", &CustomUpdateWU::getVarReferences) // **NOTE** we use the 'publicist' pattern to expose some protected properties - .def_property_readonly("_is_batched", &CustomUpdateWUInternal::isBatched); + .def_property_readonly("_dims", &CustomUpdateWUInternal::getDims); //------------------------------------------------------------------------ // genn.NeuronGroup @@ -678,6 +692,21 @@ PYBIND11_MODULE(genn, m) .def("get_code", &InitVarSnippet::Base::getCode); + //------------------------------------------------------------------------ + // genn.VarAccess + //------------------------------------------------------------------------ + pybind11::class_(m, "VarAccess") + .def(pybind11::init()) + .def(pybind11::init()) + .def(pybind11::init()) + + .def("get_neuron_dims", + [](const VarAccess &v) { return v.getDims(); }) + .def("get_synapse_dims", + [](const VarAccess &v) { return v.getDims(); }) + .def("get_custom_update_dims", + [](const VarAccess &v) { return v.getDims(); }); + //------------------------------------------------------------------------ // genn.Var //------------------------------------------------------------------------ From 25e9dce0c46f41d5085035c9fe085a2961bc2ef9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 29 Aug 2023 10:46:14 +0100 Subject: [PATCH 17/60] at least fixed PyGeNN interface --- pygenn/__init__.py | 5 +++-- pygenn/genn_groups.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pygenn/__init__.py b/pygenn/__init__.py index 552abe7982..9653790ef6 100644 --- a/pygenn/__init__.py +++ b/pygenn/__init__.py @@ -4,8 +4,9 @@ # pygenn interface from .genn import (create_var_ref, create_psm_var_ref, create_wu_pre_var_ref, create_wu_post_var_ref, create_wu_var_ref, create_egp_ref, - create_psm_egp_ref, create_wu_egp_ref, PlogSeverity, - SpanType, SynapseMatrixType, VarAccess, + create_psm_egp_ref, create_wu_egp_ref, + CustomUpdateVarAccess, NeuronVarAccess, PlogSeverity, + SpanType, SynapseMatrixType, SynapseVarAccess, VarAccessMode, VarLocation) from .genn_model import (GeNNModel, create_neuron_model, 
create_postsynaptic_model, diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index 86293bf982..5b6447d19c 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -15,7 +15,7 @@ from . import neuron_models, types from .genn import (CustomUpdateWU, SynapseMatrixConnectivity, - SynapseMatrixWeight, VarAccessDuplication, VarLocation) + SynapseMatrixWeight, VarAccessDim, VarLocation) from .model_preprocessor import prepare_model, ExtraGlobalParameter, Variable From a304015215f9825e4caba5cf965a4c9b85a86d37 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 29 Aug 2023 10:46:20 +0100 Subject: [PATCH 18/60] update feature tests --- .../test_custom_connectivity_update.py | 6 +-- tests/features/test_custom_update.py | 39 ++++++++++++------- tests/features/test_spike_propagation.py | 8 ++-- tests/features/test_spike_times.py | 1 - tests/features/test_wu_vars.py | 10 ++--- 5 files changed, 36 insertions(+), 28 deletions(-) diff --git a/tests/features/test_custom_connectivity_update.py b/tests/features/test_custom_connectivity_update.py index 8fe59d1c17..99d4a36717 100644 --- a/tests/features/test_custom_connectivity_update.py +++ b/tests/features/test_custom_connectivity_update.py @@ -3,7 +3,7 @@ from pygenn import types from pygenn import GeNNModel -from pygenn.genn import VarAccess, VarAccessMode +from pygenn.genn import SynapseVarAccess, VarAccessMode from bitarray import bitarray from bitarray.util import hex2ba @@ -26,8 +26,8 @@ weight_update_model = create_weight_update_model( "weight_update", - var_name_types=[("g", "scalar", VarAccess.READ_ONLY_DUPLICATE), - ("d", "unsigned int", VarAccess.READ_ONLY)]) + var_name_types=[("g", "scalar", SynapseVarAccess.READ_ONLY_DUPLICATE), + ("d", "unsigned int", SynapseVarAccess.READ_ONLY)]) # Snippet to initialise variable to hold its column-major index diff --git a/tests/features/test_custom_update.py b/tests/features/test_custom_update.py index 188ed9d672..504620987f 100644 --- a/tests/features/test_custom_update.py +++ b/tests/features/test_custom_update.py @@ -3,7 +3,8 @@ from pygenn import types from pygenn import GeNNModel -from pygenn.genn import VarAccess, VarAccessMode +from pygenn.genn import (CustomUpdateVarAccess, NeuronVarAccess, + SynapseVarAccess, VarAccessMode) from scipy.special import softmax from pygenn import (create_current_source_model, @@ -27,25 +28,30 @@ def test_custom_update(backend, precision, batch_size): neuron_model = create_neuron_model( "neuron", - var_name_types=[("X", "scalar", VarAccess.READ_ONLY_DUPLICATE), ("XShared", "scalar", VarAccess.READ_ONLY_SHARED_NEURON)]) + var_name_types=[("X", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE), + ("XShared", "scalar", NeuronVarAccess.READ_ONLY_SHARED_NEURON)]) current_source_model = create_current_source_model( "current_source", - var_name_types=[("X", "scalar", VarAccess.READ_ONLY_DUPLICATE), ("XShared", "scalar", VarAccess.READ_ONLY_SHARED_NEURON)]) + var_name_types=[("X", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE), + ("XShared", "scalar", NeuronVarAccess.READ_ONLY_SHARED_NEURON)]) weight_update_model = create_weight_update_model( "weight_update", - var_name_types=[("X", "scalar", VarAccess.READ_ONLY_DUPLICATE)], - pre_var_name_types=[("preX", "scalar", VarAccess.READ_ONLY_DUPLICATE), ("preXShared", "scalar", VarAccess.READ_ONLY_SHARED_NEURON)], - post_var_name_types=[("postX", "scalar", VarAccess.READ_ONLY_DUPLICATE), ("postXShared", "scalar", VarAccess.READ_ONLY_SHARED_NEURON)]) + var_name_types=[("X", "scalar", 
SynapseVarAccess.READ_ONLY_DUPLICATE)], + pre_var_name_types=[("preX", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE), + ("preXShared", "scalar", NeuronVarAccess.READ_ONLY_SHARED_NEURON)], + post_var_name_types=[("postX", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE), + ("postXShared", "scalar", NeuronVarAccess.READ_ONLY_SHARED_NEURON)]) postsynaptic_update_model = create_postsynaptic_model( "postsynaptic_update", - var_name_types=[("psmX", "scalar", VarAccess.READ_ONLY_DUPLICATE), ("psmXShared", "scalar", VarAccess.READ_ONLY_SHARED_NEURON)]) + var_name_types=[("psmX", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE), + ("psmXShared", "scalar", NeuronVarAccess.READ_ONLY_SHARED_NEURON)]) custom_update_model = create_custom_update_model( "custom_update", - var_name_types=[("X", "scalar", VarAccess.READ_ONLY_DUPLICATE)], + var_name_types=[("X", "scalar", CustomUpdateVarAccess.READ_ONLY_DUPLICATE)], var_refs=[("R", "scalar")]) set_time_custom_update_model = create_custom_update_model( @@ -210,7 +216,7 @@ def test_custom_update(backend, precision, batch_size): def test_custom_update_transpose(backend, precision, batch_size): static_pulse_duplicate_model = create_weight_update_model( "static_pulse_duplicate", - var_name_types=[("g", "scalar", VarAccess.READ_ONLY_DUPLICATE)], + var_name_types=[("g", "scalar", SynapseVarAccess.READ_ONLY_DUPLICATE)], sim_code= """ addToPost(g); @@ -266,7 +272,8 @@ def test_custom_update_transpose(backend, precision, batch_size): def test_custom_update_neuron_reduce(backend, precision, batch_size): reduction_neuron_model = create_neuron_model( "reduction_neuron", - var_name_types=[("X", "scalar", VarAccess.READ_ONLY_DUPLICATE), ("Y", "scalar", VarAccess.READ_ONLY_DUPLICATE)]) + var_name_types=[("X", "scalar", CustomUpdateVarAccess.READ_ONLY_DUPLICATE), + ("Y", "scalar", CustomUpdateVarAccess.READ_ONLY_DUPLICATE)]) softmax_1_custom_update_model = create_custom_update_model( "softmax_1", @@ -274,7 +281,7 @@ def test_custom_update_neuron_reduce(backend, precision, batch_size): """ MaxX = X; """, - var_name_types=[("MaxX", "scalar", VarAccess.REDUCE_NEURON_MAX)], + var_name_types=[("MaxX", "scalar", CustomUpdateVarAccess.REDUCE_NEURON_MAX)], var_refs=[("X", "scalar", VarAccessMode.READ_ONLY)]) softmax_2_custom_update_model = create_custom_update_model( @@ -283,7 +290,7 @@ def test_custom_update_neuron_reduce(backend, precision, batch_size): """ SumExpX = exp(X - MaxX); """, - var_name_types=[("SumExpX", "scalar", VarAccess.REDUCE_NEURON_SUM)], + var_name_types=[("SumExpX", "scalar", CustomUpdateVarAccess.REDUCE_NEURON_SUM)], var_refs=[("X", "scalar", VarAccessMode.READ_ONLY), ("MaxX", "scalar", VarAccessMode.READ_ONLY)]) @@ -346,11 +353,13 @@ def test_custom_update_batch_reduction(backend, precision, batch_size): # **TODO** once VarAccess is refactored, we should really be able to reduce neuron shared across batch dimension neuron_model = create_neuron_model( "neuron", - var_name_types=[("X", "scalar", VarAccess.READ_ONLY_DUPLICATE), ("SumX", "scalar", VarAccess.READ_ONLY)]) + var_name_types=[("X", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE), + ("SumX", "scalar", NeuronVarAccess.READ_ONLY)]) weight_update_model = create_weight_update_model( "weight_update", - var_name_types=[("X", "scalar", VarAccess.READ_ONLY_DUPLICATE), ("SumX", "scalar", VarAccess.READ_ONLY)]) + var_name_types=[("X", "scalar", SynapseVarAccess.READ_ONLY_DUPLICATE), + ("SumX", "scalar", SynapseVarAccess.READ_ONLY)]) reduction_custom_update_model = create_custom_update_model( 
"reduction_custom_update", @@ -359,7 +368,7 @@ def test_custom_update_batch_reduction(backend, precision, batch_size): SumX = X; MaxX = X; """, - var_name_types=[("MaxX", "scalar", VarAccess.REDUCE_BATCH_MAX)], + var_name_types=[("MaxX", "scalar", CustomUpdateVarAccess.REDUCE_BATCH_MAX)], var_refs=[("X", "scalar", VarAccessMode.READ_ONLY), ("SumX", "scalar", VarAccessMode.REDUCE_SUM)]) diff --git a/tests/features/test_spike_propagation.py b/tests/features/test_spike_propagation.py index 5ab728e458..7e8698ae4a 100644 --- a/tests/features/test_spike_propagation.py +++ b/tests/features/test_spike_propagation.py @@ -4,7 +4,7 @@ from pygenn import GeNNModel -from pygenn.genn import SpanType, VarAccess +from pygenn.genn import NeuronVarAccess, SpanType, SynapseVarAccess from pygenn import (create_neuron_model, create_sparse_connect_init_snippet, create_var_init_snippet, @@ -511,7 +511,7 @@ def test_reverse(backend, precision): pre_reverse_spike_source_model = create_neuron_model( "pre_reverse_spike_source", var_name_types=[("startSpike", "unsigned int"), - ("endSpike", "unsigned int", VarAccess.READ_ONLY_DUPLICATE), + ("endSpike", "unsigned int", NeuronVarAccess.READ_ONLY_DUPLICATE), ("x", "scalar")], extra_global_params=[("spikeTimes", "scalar*")], sim_code= @@ -533,7 +533,7 @@ def test_reverse(backend, precision): """ $(addToPre, $(g)); """, - var_name_types=[("g", "scalar", VarAccess.READ_ONLY)]) + var_name_types=[("g", "scalar", SynapseVarAccess.READ_ONLY)]) model = GeNNModel(precision, "test_reverse", backend=backend) model.dt = 1.0 @@ -617,7 +617,7 @@ def test_reverse_post(backend, precision): """ $(addToPre, $(g)); """, - var_name_types=[("g", "scalar", VarAccess.READ_ONLY)]) + var_name_types=[("g", "scalar", SynapseVarAccess.READ_ONLY)]) model = GeNNModel(precision, "test_reverse_post", backend=backend) model.dt = 1.0 diff --git a/tests/features/test_spike_times.py b/tests/features/test_spike_times.py index d37d5207b0..68ff4689bc 100644 --- a/tests/features/test_spike_times.py +++ b/tests/features/test_spike_times.py @@ -4,7 +4,6 @@ from pygenn import GeNNModel -from pygenn.genn import SpanType, VarAccess from pygenn import (create_neuron_model, create_sparse_connect_init_snippet, create_var_init_snippet, diff --git a/tests/features/test_wu_vars.py b/tests/features/test_wu_vars.py index deeb6ea2c1..e972f99ce3 100644 --- a/tests/features/test_wu_vars.py +++ b/tests/features/test_wu_vars.py @@ -5,7 +5,7 @@ from pygenn import GeNNModel -from pygenn.genn import VarAccess +from pygenn.genn import NeuronVarAccess from pygenn import (create_neuron_model, create_weight_update_model, init_sparse_connectivity, init_var) @@ -220,7 +220,7 @@ def test_wu_var_cont(backend, precision, fuse, delay): pre_learn_post_weight_update_model = create_weight_update_model( "pre_learn_post_weight_update", var_name_types=[("w", "scalar")], - pre_var_name_types=[("s", "scalar"), ("shift", "scalar", VarAccess.READ_ONLY)], + pre_var_name_types=[("s", "scalar"), ("shift", "scalar", NeuronVarAccess.READ_ONLY)], learn_post_code= """ @@ -234,7 +234,7 @@ def test_wu_var_cont(backend, precision, fuse, delay): pre_sim_weight_update_model = create_weight_update_model( "pre_sim_weight_update", var_name_types=[("w", "scalar")], - pre_var_name_types=[("s", "scalar"), ("shift", "scalar", VarAccess.READ_ONLY)], + pre_var_name_types=[("s", "scalar"), ("shift", "scalar", NeuronVarAccess.READ_ONLY)], sim_code= """ @@ -251,7 +251,7 @@ def test_wu_var_cont(backend, precision, fuse, delay): post_learn_post_weight_update_model = 
create_weight_update_model( "post_learn_post_weight_update", var_name_types=[("w", "scalar")], - post_var_name_types=[("s", "scalar"), ("shift", "scalar", VarAccess.READ_ONLY)], + post_var_name_types=[("s", "scalar"), ("shift", "scalar", NeuronVarAccess.READ_ONLY)], learn_post_code= """ @@ -265,7 +265,7 @@ def test_wu_var_cont(backend, precision, fuse, delay): post_sim_weight_update_model = create_weight_update_model( "post_sim_weight_update", var_name_types=[("w", "scalar")], - post_var_name_types=[("s", "scalar"), ("shift", "scalar", VarAccess.READ_ONLY)], + post_var_name_types=[("s", "scalar"), ("shift", "scalar", NeuronVarAccess.READ_ONLY)], sim_code= """ From 3874ed1502f13f74a4b70c5f7eed73f844293fc7 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 29 Aug 2023 16:43:49 +0100 Subject: [PATCH 19/60] fixed warning in test --- tests/unit/typeChecker.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index 859e6c5392..b3ab09f11b 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -109,7 +109,7 @@ class TestLibraryEnvironment : public TypeChecker::EnvironmentBase throw TypeChecker::TypeCheckError(); } - virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final + virtual std::vector getTypes(const Token &name, ErrorHandlerBase&) final { const auto [typeBegin, typeEnd] = m_Library.get().equal_range(name.lexeme); if (typeBegin == typeEnd) { From a611308a4bfe3d30d28c13fc5030d1aa5954cb65 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 29 Aug 2023 16:44:25 +0100 Subject: [PATCH 20/60] fixed some more warnings --- include/genn/genn/code_generator/environment.h | 2 +- src/genn/genn/code_generator/backendBase.cc | 10 +++++----- src/genn/genn/code_generator/initGroupMerged.cc | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index bc1bd7a581..4e057c8fc6 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -772,7 +772,7 @@ class VarRefCachePolicy //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - bool shouldAlwaysCopy(G&, const Models::Base::VarRef &var) const + bool shouldAlwaysCopy(G&, const Models::Base::VarRef&) const { // **NOTE** something else is managing the actual variables // and is therefore responsible for copying between delay slots etc diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 8cd6598d62..90a2538cb1 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -411,7 +411,7 @@ void buildStandardCustomUpdateEnvironment(const BackendBase &backend, Environmen } //-------------------------------------------------------------------------- template -void buildStandardCustomUpdateWUEnvironment(const BackendBase &backend, EnvironmentGroupMergedField &env, unsigned int batchSize) +void buildStandardCustomUpdateWUEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) { // Add batch offset if group is batched if((env.getGroup().getArchetype().getDims() & VarAccessDim::BATCH) && (batchSize > 1)) { @@ -579,12 +579,12 @@ void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const { - 
buildStandardCustomUpdateWUEnvironment(*this, env, batchSize); + buildStandardCustomUpdateWUEnvironment(env, batchSize); } //----------------------------------------------------------------------- void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const { - buildStandardCustomUpdateWUEnvironment(*this, env, batchSize); + buildStandardCustomUpdateWUEnvironment(env, batchSize); } //----------------------------------------------------------------------- void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const @@ -611,13 +611,13 @@ void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const { buildCustomUpdateWUSizeEnvironment(*this, env); - buildStandardCustomUpdateWUEnvironment(*this, env, batchSize); + buildStandardCustomUpdateWUEnvironment(env, batchSize); } //----------------------------------------------------------------------- void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const { buildCustomUpdateWUSizeEnvironment(*this, env); - buildStandardCustomUpdateWUEnvironment(*this, env, batchSize); + buildStandardCustomUpdateWUEnvironment(env, batchSize); } //----------------------------------------------------------------------- void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index f5fbc0280b..c358b00312 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -997,7 +997,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePreInitGroupMerg return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomConnectivityUpdatePreInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) +void CustomConnectivityUpdatePreInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int) { genInitNeuronVarCode(backend, env, *this, "", "size", 0, 1); } @@ -1026,7 +1026,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePostInitGroupMer return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomConnectivityUpdatePostInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) +void CustomConnectivityUpdatePostInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int) { // Initialise presynaptic custom connectivity update variables genInitNeuronVarCode(backend, env, *this, "", "size", 0, 1); From 8af209371785d7b845d6e3d4b253281c32bffe13 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 29 Aug 2023 17:19:49 +0100 Subject: [PATCH 21/60] fixed typo in NeuronVarAccess::READ_ONLY and NeuronVarAccess::READ_ONLY_DUPLICATE --- include/genn/genn/varAccess.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h index 47715f8559..5a2e58b4fa 100644 --- a/include/genn/genn/varAccess.h +++ b/include/genn/genn/varAccess.h @@ -45,8 +45,8 @@ enum class VarAccessDim : unsigned int enum class NeuronVarAccess : unsigned int { READ_WRITE = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::NEURON) | 
static_cast(VarAccessDim::BATCH), - READ_ONLY = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::NEURON) | static_cast(VarAccessDim::BATCH), - READ_ONLY_DUPLICATE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::NEURON), + READ_ONLY = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::NEURON), + READ_ONLY_DUPLICATE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::NEURON) | static_cast(VarAccessDim::BATCH), READ_ONLY_SHARED_NEURON = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::BATCH), }; From 7047f0bbabb95a1697eb72a3bcd680cd008f2b78 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 29 Aug 2023 17:20:27 +0100 Subject: [PATCH 22/60] add clarifying comment to unit test --- tests/unit/customUpdate.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/customUpdate.cc b/tests/unit/customUpdate.cc index ce9d6a7c77..d87813a8c7 100644 --- a/tests/unit/customUpdate.cc +++ b/tests/unit/customUpdate.cc @@ -489,7 +489,7 @@ TEST(CustomUpdates, BatchingVars) VarValues izkVarVals{{"V", 0.0}, {"U", 0.0}, {"a", 0.02}, {"b", 0.2}, {"c", -65.0}, {"d", 8.0}}; auto *pop = model.addNeuronPopulation("Pop", 10, {}, izkVarVals); - // Create updates where variable is shared and references vary + // Create updates where variable has same dimensionality as references but dimensionality varies VarValues sumVarValues{{"sum", 1.0}}; VarReferences sumVarReferences1{{"a", createVarRef(pop, "V")}, {"b", createVarRef(pop, "U")}}; VarReferences sumVarReferences2{{"a", createVarRef(pop, "a")}, {"b", createVarRef(pop, "b")}}; From 8253564157d8c5dc10d8da07901087e3a95a4b3b Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 29 Aug 2023 17:21:11 +0100 Subject: [PATCH 23/60] fixed logic in ``VarReference::getDims`` and ``WUVarReference::getDims`` so correct defaults will be used --- src/genn/genn/models.cc | 50 +++++++++++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index a4a278d6e9..f079fabb49 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -75,13 +75,32 @@ std::string VarReference::getTargetName() const //---------------------------------------------------------------------------- VarAccessDim VarReference::getDims() const { - const VarAccessDim varDims = getVar().access.getDims(); + const auto &varAccess = getVar().access; return std::visit( Utils::Overload{ - [varDims](const CURef &ref) { return clearDim(ref.group->getDims(), varDims); }, - [varDims](const CCUPreRef&){ return clearDim(varDims, VarAccessDim::BATCH); }, - [varDims](const CCUPostRef&){ return clearDim(varDims, VarAccessDim::BATCH); }, - [varDims](const auto&) { return varDims; }}, + // If reference is to a custom update variable, + // remove dimensions from those of update + [&varAccess](const CURef &ref) + { + return clearDim(ref.group->getDims(), + varAccess.getDims()); + }, + // Otherwise, if reference is to the presynaptic variables of a custom connectivity update, + // remove BATCH dimension as these are never batched + [&varAccess](const CCUPreRef&) + { + return clearDim(varAccess.getDims(), + VarAccessDim::BATCH); + }, + // Otherwise, if reference is to the postsynaptic variables of a custom connectivity update, + // remove BATCH dimension as these are never batched + [&varAccess](const CCUPostRef&) + { + return clearDim(varAccess.getDims(), + VarAccessDim::BATCH); + }, + // Otherwise, use dimensionality 
directly + [&varAccess](const auto&) { return varAccess.getDims(); }}, m_Detail); } //---------------------------------------------------------------------------- @@ -171,12 +190,25 @@ std::string WUVarReference::getTargetName() const //---------------------------------------------------------------------------- VarAccessDim WUVarReference::getDims() const { - const VarAccessDim varDims = getVar().access.getDims(); + const auto &varAccess = getVar().access; return std::visit( Utils::Overload{ - [varDims](const CURef &ref) { return clearDim(ref.group->getDims(), varDims); }, - [varDims](const CCURef&) { return clearDim(varDims, VarAccessDim::BATCH); }, - [varDims](const WURef&) { return varDims; }}, + // If reference is to a custom update variable, + // remove dimensions from those of update + [&varAccess](const CURef &ref) + { + return clearDim(ref.group->getDims(), + varAccess.getDims()); + }, + // Otherwise, if reference is to the synaptic variables of a custom connectivity update, + // remove BATCH dimension as these are never batched + [&varAccess](const CCURef&) + { + return clearDim(varAccess.getDims(), + VarAccessDim::BATCH); + }, + // Otherwise, use dimensionality directly + [&varAccess](const WURef&){ return varAccess.getDims(); }}, m_Detail); } //---------------------------------------------------------------------------- From b0fd6702bbc9d4d6038ba99b4cc52942f9e65755 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 29 Aug 2023 17:35:55 +0100 Subject: [PATCH 24/60] fixed typo which meant death test was dying for incorrect reason! --- tests/unit/customUpdate.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/unit/customUpdate.cc b/tests/unit/customUpdate.cc index d87813a8c7..b282dbb3aa 100644 --- a/tests/unit/customUpdate.cc +++ b/tests/unit/customUpdate.cc @@ -465,7 +465,7 @@ TEST(CustomUpdates, WUVarSynapseGroupChecks) VarValues sumVarValues{{"sum", 0.0}}; WUVarReferences sumVarReferences1{{"a", createWUVarRef(sg1, "g")}, {"b", createWUVarRef(sg1, "g")}}; - WUVarReferences sumVarReferences2{{"a", createWUVarRef(sg1, "g")}, {"b", createWUVarRef(sg2, "d")}}; + WUVarReferences sumVarReferences2{{"a", createWUVarRef(sg1, "g")}, {"b", createWUVarRef(sg2, "g")}}; model.addCustomUpdate("SumWeight1", "CustomUpdate", {}, sumVarValues, sumVarReferences1); @@ -547,7 +547,7 @@ TEST(CustomUpdates, BatchingWriteShared) VarValues izkVarVals{{"V", 0.0}, {"U", 0.0}, {"a", 0.02}, {"b", 0.2}, {"c", -65.0}, {"d", 8.0}}; auto *pop = model.addNeuronPopulation("Pop", 10, {}, izkVarVals); - // Create custom update which tries to create a read-write refernece to a (which isn't batched) + // Create custom update which tries to create a read-write reference to a (which isn't batched) VarReferences reduceVarReferences{{"var", createVarRef(pop, "V")}, {"reduction", createVarRef(pop, "U")}}; try { model.addCustomUpdate("Sum1", "CustomUpdate", @@ -570,9 +570,10 @@ TEST(CustomUpdates, WriteNeuronShared) // Create custom update which tries to create a read-write reference to a (which isn't per-neuron) VarValues sum2VarValues{{"mult", 1.0}}; VarReferences sum2VarReferences{{"a", createVarRef(pop, "a")}, {"b", createVarRef(pop, "V")}}; + model.addCustomUpdate("Sum1", "CustomUpdate", + {}, sum2VarValues, sum2VarReferences); try { - model.addCustomUpdate("Sum1", "CustomUpdate", - {}, sum2VarValues, sum2VarReferences); + model.finalise(); FAIL(); } catch(const std::runtime_error &) { From 8a21a5fb0348129fa5c3708bec3f5916ce9031f3 Mon Sep 17 00:00:00 2001 From: neworderofjamie 
Date: Tue, 29 Aug 2023 17:54:30 +0100 Subject: [PATCH 25/60] Fixed a couple of issues in ``CustomUpdateBase::isReduction`` * Variables always have ``CustomUpdateVarAccess`` access type * ``getDims`` should be called on variable reference not directly on underlying variable --- include/genn/genn/customUpdate.h | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index 903276e82d..e96be26f29 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -85,17 +85,18 @@ class GENN_EXPORT CustomUpdateBase const std::vector getUpdateCodeTokens() const{ return m_UpdateCodeTokens; } - template + template bool isReduction(const std::unordered_map &varRefs, VarAccessDim reduceDim) const { - // Return true if any variables have REDUCE flag in their access mode and have reduction dimension + // Return true if any variables have REDUCE flag in their access mode and have reduction dimension + // **NOTE** this is correct because custom update variable access types are defined subtractively const auto vars = getCustomUpdateModel()->getVars(); if(std::any_of(vars.cbegin(), vars.cend(), [reduceDim](const Models::Base::Var &v) { return ((v.access & VarAccessModeAttribute::REDUCE) - && (v.access.template getDims() & reduceDim)); + && (v.access.getDims() & reduceDim)); })) { return true; @@ -104,10 +105,10 @@ class GENN_EXPORT CustomUpdateBase // Loop through all variable references for(const auto &modelVarRef : getCustomUpdateModel()->getVarRefs()) { // If custom update model reduces into this variable reference - // and the variable it targets has reduction dimension + // and the variable it targets doesn't have reduction dimension const auto &varRef = varRefs.at(modelVarRef.name); if ((modelVarRef.access & VarAccessModeAttribute::REDUCE) - && (varRef.getVar().access.template getDims() & reduceDim)) + && !(varRef.getDims() & reduceDim)) { return true; } @@ -263,8 +264,8 @@ class GENN_EXPORT CustomUpdate : public CustomUpdateBase //------------------------------------------------------------------------ // Protected const methods //------------------------------------------------------------------------ - bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDim::BATCH); } - bool isNeuronReduction() const { return isReduction(getVarReferences(), VarAccessDim::NEURON); } + bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDim::BATCH); } + bool isNeuronReduction() const { return isReduction(getVarReferences(), VarAccessDim::NEURON); } const NeuronGroup *getDelayNeuronGroup() const { return m_DelayNeuronGroup; } @@ -316,7 +317,7 @@ class GENN_EXPORT CustomUpdateWU : public CustomUpdateBase //------------------------------------------------------------------------ // Protected const methods //------------------------------------------------------------------------ - bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDim::BATCH); } + bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDim::BATCH); } bool isTransposeOperation() const; SynapseGroupInternal *getSynapseGroup() const { return m_SynapseGroup; } From da8f180aa0750aee4d64a49e7c474b60c3f12ae9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 29 Aug 2023 18:03:42 +0100 Subject: [PATCH 26/60] infact, we should ALWAYS call getDims directly on variable references rather than on Var --- include/genn/genn/customUpdate.h 
| 4 +- .../backends/single_threaded_cpu/backend.cc | 4 +- src/genn/genn/code_generator/backendBase.cc | 6 +- .../customConnectivityUpdateGroupMerged.cc | 59 +------------------ .../code_generator/customUpdateGroupMerged.cc | 2 +- src/genn/genn/customConnectivityUpdate.cc | 4 +- src/genn/genn/customUpdate.cc | 32 +--------- 7 files changed, 15 insertions(+), 96 deletions(-) diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index e96be26f29..a1f6c154a7 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -118,7 +118,7 @@ class GENN_EXPORT CustomUpdateBase } //! Helper function to check if variable reference types match those specified in model - template + template void checkVarReferenceDims(const std::unordered_map& varRefs, unsigned int batchSize) { // Loop through variable references and or together their dimensions to get dimensionality of update @@ -133,7 +133,7 @@ class GENN_EXPORT CustomUpdateBase // If the shape of the references variable doesn't match the dimensionality // of the custom update, check its access mode isn't read-write - if((m_Dims != varRef.getVar().access.template getDims()) + if((m_Dims != varRef.getDims()) && (modelVarRef.access == VarAccessMode::READ_WRITE)) { throw std::runtime_error("Variable references to lower-dimensional variables cannot be read-write."); diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 722d3ddcae..d4a14e3546 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -2018,7 +2018,7 @@ void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateG [&cg](const Models::VarReference &varRef, const std::string &index) { return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, 1, - varRef.getVar().access.getDims(), index); + varRef.getDims(), index); }); } //-------------------------------------------------------------------------- @@ -2028,7 +2028,7 @@ void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateW env, cg, idxName, [&cg](const Models::WUVarReference &varRef, const std::string &index) { - return cg.getVarRefIndex(1, varRef.getVar().access.getDims(), index); + return cg.getVarRefIndex(1, varRef.getDims(), index); }); } } // namespace GeNN::CodeGenerator::SingleThreadedCPU diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 90a2538cb1..1fbae498d8 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -706,18 +706,18 @@ std::vector BackendBase::genInitReductionTargets(C [batchSize, &cg](const Models::VarReference &varRef, const std::string &index) { return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, batchSize, - varRef.getVar().access.getDims(), index); + varRef.getDims(), index); }); } //----------------------------------------------------------------------- std::vector BackendBase::genInitReductionTargets(CodeStream &os, const CustomUpdateWUGroupMerged &cg, unsigned int batchSize, const std::string &idx) const { - return genInitReductionTargets( + return genInitReductionTargets( os, cg, batchSize, idx, [batchSize, &cg](const Models::WUVarReference &varRef, const std::string &index) { - return cg.getVarRefIndex(batchSize, varRef.getVar().access.getDims(), index); + return cg.getVarRefIndex(batchSize, varRef.getDims(), index); }); } } // namespace GeNN::CodeGenerator \ No 
newline at end of file diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 3544ff2b7e..d68f4fc38f 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -37,11 +37,11 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t { boost::uuids::detail::sha1 hashA; Type::updateHash(a.getVar().type, hashA); - Utils::updateHash(a.getVar().access.template getDims(), hashA); + Utils::updateHash(a.getDims(), hashA); boost::uuids::detail::sha1 hashB; Type::updateHash(b.getVar().type, hashB); - Utils::updateHash(b.getVar().access.template getDims(), hashB); + Utils::updateHash(b.getDims(), hashB); return (hashA.get_digest() < hashB.get_digest()); }); @@ -56,61 +56,6 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t { return (vars.size() == m_SortedDependentVars.front().size()); })); - - - /*addField(Uint32, "rowStride", - [&backend](const auto &cg, size_t) - { - const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); - return std::to_string(backend.getSynapticMatrixRowStride(*sgInternal)); - }); - - - assert(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE); - addField(getArchetype().getSynapseGroup()->getSparseIndType().createPointer(), "ind", - [&backend](const auto &cg, size_t) - { - return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); - }); - - addField(Uint32.createPointer(), "rowLength", - [&backend](const auto &cg, size_t) - { - return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); - }); - - // If some presynaptic variables are delayed, add delay pointer - if (getArchetype().getPreDelayNeuronGroup() != nullptr) { - addField(Uint32.createPointer(), "preSpkQuePtr", - [&backend](const auto &cg, size_t) - { - return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getPreDelayNeuronGroup()->getName(); - }); - } - - // If some postsynaptic variables are delayed, add delay pointer - if (getArchetype().getPostDelayNeuronGroup() != nullptr) { - addField(Uint32.createPointer(), "postSpkQuePtr", - [&backend](const auto &cg, size_t) - { - return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getPostDelayNeuronGroup()->getName(); - }); - } - - - // Add variables to struct - - - // Loop through sorted dependent variables - for(size_t i = 0; i < getSortedArchetypeDependentVars().size(); i++) { - auto resolvedType = getSortedArchetypeDependentVars().at(i).getVar().type.resolve(getTypeContext()); - addField(resolvedType.createPointer(), "_dependentVar" + std::to_string(i), - [i, &backend, this](const auto&, size_t g) - { - const auto &varRef = m_SortedDependentVars[g][i]; - return backend.getDeviceVarPrefix() + varRef.getVar().name + varRef.getTargetName(); - }); - }*/ } //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type CustomConnectivityUpdateGroupMerged::getHashDigest() const diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index ae9022c9a1..2ff12b9873 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -71,7 +71,7 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E [this, 
batchSize, &varEnv](const std::string&, const Models::VarReference &v) { return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, batchSize, - v.getVar().access.getDims(), "$(id)"); + v.getDims(), "$(id)"); }); Transpiler::ErrorHandler errorHandler("Custom update '" + getArchetype().getName() + "' update code"); diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index 2ad273055a..fdee52420c 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -180,7 +180,7 @@ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) if (std::any_of(getPreVarReferences().cbegin(), getPreVarReferences().cend(), [](const auto &v) { - return (v.second.getVar().access.template getDims() & VarAccessDim::BATCH); + return (v.second.getDims() & VarAccessDim::BATCH); })) { throw std::runtime_error("Presynaptic variables referenced by CustomConnectivityUpdate must be SHARED across batches"); @@ -190,7 +190,7 @@ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) if (std::any_of(getPostVarReferences().cbegin(), getPostVarReferences().cend(), [](const auto &v) { - return (v.second.getVar().access.template getDims() & VarAccessDim::BATCH); + return (v.second.getDims() & VarAccessDim::BATCH); })) { throw std::runtime_error("Postsynaptic variables referenced by CustomConnectivityUpdate must be SHARED across batches"); diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index a5130ddc80..64bb15b6af 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -134,32 +134,6 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro // Check variable reference types Models::checkVarReferenceTypes(m_VarReferences, getCustomUpdateModel()->getVarRefs()); - // Update is per-neuron if any variables or variable reference targets have neuron dimension - /*const auto modelVars = getCustomUpdateModel()->getVars(); - m_PerNeuron = std::any_of(m_VarReferences.cbegin(), m_VarReferences.cend(), - [](const auto& v) - { - return (v.second.getVar().access.template getDims() & VarAccessDim::NEURON); - }); - m_PerNeuron |= std::any_of(modelVars.cbegin(), modelVars.cend(), - [](const Models::Base::Var& v) - { - return (v.access.template getDims() & VarAccessDim::NEURON); - }); - - // Loop through all variable references - for(const auto &modelVarRef : getCustomUpdateModel()->getVarRefs()) { - const auto &varRef = m_VarReferences.at(modelVarRef.name); - - // If custom update is per-neuron, check that any variable references to variables without NEURON axis are read-only - // **NOTE** if custom update isn't per-neuron, it's totally fine to write to SHARED_NEURON variables - if(m_PerNeuron && !(varRef.getVar().access.getDims() & VarAccessDim::NEURON) - && (modelVarRef.access == VarAccessMode::READ_WRITE)) - { - throw std::runtime_error("Variable references to SHARED_NEURON variables in per-neuron custom updates cannot be read-write."); - } - } - */ // Check only one type of reduction is specified if (isBatchReduction() && isNeuronReduction()) { throw std::runtime_error("Custom updates cannot perform batch and neuron reductions simultaneously."); @@ -179,7 +153,7 @@ void CustomUpdate::finalise(double dt, unsigned int batchSize) CustomUpdateBase::finalise(dt); // Check variable reference batching - checkVarReferenceDims(m_VarReferences, batchSize); + checkVarReferenceDims(m_VarReferences, batchSize); // If any variable references have delays 
auto delayRef = std::find_if(m_VarReferences.cbegin(), m_VarReferences.cend(), @@ -218,7 +192,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdate::getHashDigest() const Utils::updateHash((v.second.getDelayNeuronGroup() == nullptr), hash); // Update hash with target variable dimensions as this effects indexing code - Utils::updateHash(v.second.getVar().access.getDims(), hash); + Utils::updateHash(v.second.getDims(), hash); } return hash.get_digest(); } @@ -295,7 +269,7 @@ void CustomUpdateWU::finalise(double dt, unsigned int batchSize) CustomUpdateBase::finalise(dt); // Check variable reference types - checkVarReferenceDims(m_VarReferences, batchSize); + checkVarReferenceDims(m_VarReferences, batchSize); } //---------------------------------------------------------------------------- bool CustomUpdateWU::isTransposeOperation() const From 608c5e66a684b62f135aa22efd30750760bccf7a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 30 Aug 2023 18:11:04 +0100 Subject: [PATCH 27/60] Correct resolution of dimensions in runner and initialisation * moved resolution of variable dimensions into adaptor classes * used this in generic initialisation and runner code * tidied up variable size calculation in runner --- include/genn/genn/currentSourceInternal.h | 2 + .../genn/customConnectivityUpdateInternal.h | 6 + include/genn/genn/customUpdate.h | 11 +- include/genn/genn/neuronGroupInternal.h | 2 + include/genn/genn/synapseGroupInternal.h | 9 ++ .../genn/code_generator/generateRunner.cc | 130 ++++++++++-------- .../genn/code_generator/initGroupMerged.cc | 6 +- 7 files changed, 104 insertions(+), 62 deletions(-) diff --git a/include/genn/genn/currentSourceInternal.h b/include/genn/genn/currentSourceInternal.h index 05a3d50457..b6c7277500 100644 --- a/include/genn/genn/currentSourceInternal.h +++ b/include/genn/genn/currentSourceInternal.h @@ -54,6 +54,8 @@ class CurrentSourceVarAdapter const std::string &getNameSuffix() const{ return m_CS.getName(); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + private: //---------------------------------------------------------------------------- // Members diff --git a/include/genn/genn/customConnectivityUpdateInternal.h b/include/genn/genn/customConnectivityUpdateInternal.h index e603501cf6..da6ebb2e5d 100644 --- a/include/genn/genn/customConnectivityUpdateInternal.h +++ b/include/genn/genn/customConnectivityUpdateInternal.h @@ -61,6 +61,8 @@ class CustomConnectivityUpdateVarAdapter const std::string &getNameSuffix() const{ return m_CU.getName(); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + private: //---------------------------------------------------------------------------- // Members @@ -90,6 +92,8 @@ class CustomConnectivityUpdatePreVarAdapter const std::string &getNameSuffix() const{ return m_CU.getName(); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + private: //---------------------------------------------------------------------------- // Members @@ -119,6 +123,8 @@ class CustomConnectivityUpdatePostVarAdapter const std::string &getNameSuffix() const{ return m_CU.getName(); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + private: //---------------------------------------------------------------------------- // Members diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index a1f6c154a7..d6d947ac99 100644 --- 
a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -50,6 +50,9 @@ class GENN_EXPORT CustomUpdateBase //! Is var init code required for any variables in this custom update group's custom update model? bool isVarInitRequired() const; + //! Get dimensions of this custom update + VarAccessDim getDims() const{ return m_Dims; } + protected: CustomUpdateBase(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, @@ -70,9 +73,6 @@ class GENN_EXPORT CustomUpdateBase bool isZeroCopyEnabled() const; - //! Get dimensions of this custom update - VarAccessDim getDims() const{ return m_Dims; } - //! Updates hash with custom update /*! NOTE: this can only be called after model is finalized */ void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -208,6 +208,11 @@ class CustomUpdateVarAdapter const std::string &getNameSuffix() const{ return m_CU.getName(); } + VarAccessDim getVarDims(const Models::Base::Var &var) const + { + return clearDim(m_CU.getDims(), var.access.getDims()); + } + private: //---------------------------------------------------------------------------- // Members diff --git a/include/genn/genn/neuronGroupInternal.h b/include/genn/genn/neuronGroupInternal.h index 5905735aad..5120b5c482 100644 --- a/include/genn/genn/neuronGroupInternal.h +++ b/include/genn/genn/neuronGroupInternal.h @@ -78,6 +78,8 @@ class NeuronVarAdapter const std::string &getNameSuffix() const{ return m_NG.getName(); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + private: //---------------------------------------------------------------------------- // Members diff --git a/include/genn/genn/synapseGroupInternal.h b/include/genn/genn/synapseGroupInternal.h index c1cc18823c..7031994651 100644 --- a/include/genn/genn/synapseGroupInternal.h +++ b/include/genn/genn/synapseGroupInternal.h @@ -118,6 +118,8 @@ class SynapsePSMVarAdapter bool isVarDelayed(const std::string &) const { return false; } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + private: //---------------------------------------------------------------------------- // Members @@ -167,6 +169,9 @@ class SynapseWUVarAdapter const std::unordered_map &getInitialisers() const{ return m_SG.getWUVarInitialisers(); } const std::string &getNameSuffix() const{ return m_SG.getName(); } + + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + private: //---------------------------------------------------------------------------- // Members @@ -196,6 +201,8 @@ class SynapseWUPreVarAdapter bool isVarDelayed(const std::string&) const{ return (m_SG.getDelaySteps() != 0); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + private: //---------------------------------------------------------------------------- // Members @@ -225,6 +232,8 @@ class SynapseWUPostVarAdapter bool isVarDelayed(const std::string&) const{ return (m_SG.getBackPropDelaySteps() != 0); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + private: //---------------------------------------------------------------------------- // Members diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 430868c94f..4e9aaf3409 100644 --- 
a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -27,20 +27,45 @@ using namespace GeNN::CodeGenerator; //-------------------------------------------------------------------------- namespace { -unsigned int getNumVarCopies(VarAccessDim varDims, unsigned int batchSize, bool batched = true) +size_t getNumVarCopies(VarAccessDim varDims, size_t batchSize, bool batched = true) { return ((varDims & VarAccessDim::BATCH) && batched) ? batchSize : 1; } //-------------------------------------------------------------------------- -unsigned int getNumVarElements(VarAccessDim varDims, unsigned int numNeurons) +size_t getNumNeuronVarElements(VarAccessDim varDims, size_t numNeurons) { return (varDims & VarAccessDim::NEURON) ? numNeurons : 1; } //-------------------------------------------------------------------------- -unsigned int getVarSize(VarAccessDim varDims, unsigned int numElements, unsigned int batchSize, - unsigned int delaySlots = 1, bool batched = true) +size_t getNeuronVarSize(VarAccessDim varDims, size_t numElements, size_t batchSize, + size_t delaySlots = 1, bool batched = true) { - return getNumVarCopies(varDims, batchSize, batched) * getNumVarElements(varDims, numElements) * delaySlots; + return getNumVarCopies(varDims, batchSize, batched) * getNumNeuronVarElements(varDims, numElements) * delaySlots; +} +//-------------------------------------------------------------------------- +size_t getSynapseVarSize(VarAccessDim varDims, size_t numPre, size_t rowStride, size_t batchSize, + std::optional kernelSize = std::nullopt, bool batched = true) +{ + const size_t numCopies = getNumVarCopies(varDims, batchSize, batched); + const bool pre = (varDims & VarAccessDim::PRE_NEURON); + const bool post = (varDims & VarAccessDim::POST_NEURON); + if(pre && post) { + if(kernelSize) { + return kernelSize.value() * numCopies; + } + else { + return numPre * rowStride * numCopies; + } + } + else if(varDims & VarAccessDim::PRE_NEURON) { + return numPre * numCopies; + } + else if(varDims & VarAccessDim::POST_NEURON) { + return rowStride * numCopies; + } + else { + return numCopies; + } } //-------------------------------------------------------------------------- void genSpikeMacros(CodeStream &os, const NeuronGroupInternal &ng, bool trueSpike) @@ -386,7 +411,7 @@ void genRunnerVars(const ModelSpecMerged &modelMerged, const BackendBase &backen const auto resolvedType = var.type.resolve(modelMerged.getModel().getTypeContext()); genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, resolvedType, var.name + group.getName(), varAdaptor.getLoc(var.name), - autoInitialized, getSizeFn(group, var), mem, statePushPullFunctions); + autoInitialized, getSizeFn(group, varAdaptor.getVarDims(var)), mem, statePushPullFunctions); // Loop through EGPs required to initialize variable for(const auto &egp : varInit.getSnippet()->getExtraGlobalParams()) { @@ -410,7 +435,7 @@ void genRunnerFusedVars(const ModelSpecMerged &modelMerged, const BackendBase &b const auto resolvedType = var.type.resolve(modelMerged.getModel().getTypeContext()); backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, resolvedType, var.name + varAdaptor.getNameSuffix(), varAdaptor.getLoc(var.name), - getSizeFn(group, var), mem); + getSizeFn(group, varAdaptor.getVarDims(var)), mem); // Loop through EGPs required to initialize variable for(const auto &egp : 
varAdaptor.getInitialisers().at(var.name).getSnippet()->getExtraGlobalParams()) { @@ -438,7 +463,8 @@ void genRunnerFusedVarPushPull(const ModelSpecMerged &modelMerged, const Backend { backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, resolvedType, var.name + group.getName(), - varAdaptor.getLoc(var.name), autoInitialized, getSizeFn(group, var)); + varAdaptor.getLoc(var.name), autoInitialized, + getSizeFn(group, varAdaptor.getVarDims(var))); }); } } @@ -1082,7 +1108,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, for(const auto &var : neuronModel->getVars()) { const auto &varInit = n.second.getVarInitialisers().at(var.name); const unsigned int numCopies = getNumVarCopies(var.access.getDims(), batchSize); - const unsigned int numElements = getNumVarElements(var.access.getDims(), n.second.getNumNeurons()); + const unsigned int numElements = getNumNeuronVarElements(var.access.getDims(), n.second.getNumNeurons()); const size_t count = n.second.isVarQueueRequired(var.name) ? numCopies * numElements * n.second.getNumDelaySlots() : numCopies * numElements; const bool autoInitialized = !Utils::areTokensEmpty(varInit.getCodeTokens()); const auto resolvedType = var.type.resolve(modelMerged.getModel().getTypeContext()); @@ -1150,10 +1176,9 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, genRunnerVars(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, runnerPushFunc, runnerPullFunc, *cs, mem, currentSourcePushPullFunctions, - [batchSize, &n](const CurrentSourceInternal&, const Models::Base::Var &var) + [batchSize, &n](const CurrentSourceInternal&, VarAccessDim varDims) { - return getVarSize(var.access.getDims(), - n.second.getNumNeurons(), batchSize); + return getNeuronVarSize(varDims, n.second.getNumNeurons(), batchSize); }); genRunnerEGPs(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerExtraGlobalParamFunc, *cs); @@ -1174,10 +1199,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, runnerPushFunc, runnerPullFunc, model.getCustomUpdates(), mem, statePushPullFunctions, - [batchSize](const CustomUpdateInternal &c, const Models::Base::Var &var) + [batchSize](const CustomUpdateInternal &c, VarAccessDim varDims) { - return getVarSize(var.access.getDims(), - c.getSize(), batchSize, 1, c.getDims() & VarAccessDim::BATCH); + return getNeuronVarSize(varDims, c.getSize(), batchSize, 1, + c.getDims() & VarAccessDim::BATCH); }); genCustomUpdate( @@ -1185,14 +1210,13 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, runnerPushFunc, runnerPullFunc, model.getCustomWUUpdates(), mem, statePushPullFunctions, - [batchSize, &backend](const CustomUpdateWUInternal &c, const Models::Base::Var &var) + [batchSize, &backend](const CustomUpdateWUInternal &c, VarAccessDim varDims) { const SynapseGroupInternal *sg = c.getSynapseGroup(); - const size_t count = ((sg->getMatrixType() & SynapseMatrixWeight::KERNEL) - ? 
sg->getKernelSizeFlattened() - : sg->getSrcNeuronGroup()->getNumNeurons() * backend.getSynapticMatrixRowStride(*sg)); - return getVarSize(var.access.getDims(), count, - batchSize, 1, c.getDims() & VarAccessDim::BATCH); + return getSynapseVarSize(varDims, sg->getSrcNeuronGroup()->getNumNeurons(), + backend.getSynapticMatrixRowStride(*sg), batchSize, + (sg->getMatrixType() & SynapseMatrixWeight::KERNEL) ? std::make_optional(sg->getKernelSizeFlattened()) : std::nullopt, + c.getDims() & VarAccessDim::BATCH); }); allVarStreams << std::endl; @@ -1206,7 +1230,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, genRunnerVars(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, runnerPushFunc, runnerPullFunc, c.second, mem, customConnectivityPushPullFunctions, - [&backend](const CustomConnectivityUpdateInternal &c, const Models::Base::Var&) + [&backend](const CustomConnectivityUpdateInternal &c, VarAccessDim) { const SynapseGroupInternal *sg = c.getSynapseGroup(); return (sg->getSrcNeuronGroup()->getNumNeurons() * backend.getSynapticMatrixRowStride(*sg)); @@ -1216,7 +1240,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, genRunnerVars(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, runnerPushFunc, runnerPullFunc, c.second, mem, customConnectivityPushPullFunctions, - [](const CustomConnectivityUpdateInternal &c, const Models::Base::Var&) + [](const CustomConnectivityUpdateInternal &c, VarAccessDim) { return c.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(); }); @@ -1226,7 +1250,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, genRunnerVars(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, runnerPushFunc, runnerPullFunc, c.second, mem, customConnectivityPushPullFunctions, - [&backend](const CustomConnectivityUpdateInternal &c, const Models::Base::Var&) + [&backend](const CustomConnectivityUpdateInternal &c, VarAccessDim) { return c.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons(); }); @@ -1271,10 +1295,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, genRunnerFusedVars(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, *sg, mem, - [batchSize](const SynapseGroupInternal &sg, const Models::Base::Var &var) + [batchSize](const SynapseGroupInternal &sg, VarAccessDim varDims) { - return getVarSize(var.access.getDims(), sg.getTrgNeuronGroup()->getNumNeurons(), - batchSize); + return getNeuronVarSize(varDims, sg.getTrgNeuronGroup()->getNumNeurons(), + batchSize); }); } // Loop through fused outgoing synapse populations with weightupdate models that have presynaptic output @@ -1289,10 +1313,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const unsigned int preDelaySlots = (sg->getDelaySteps() == NO_DELAY) ? 
1 : sg->getSrcNeuronGroup()->getNumDelaySlots(); genRunnerFusedVars(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, *sg, mem, - [batchSize, preDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) + [batchSize, preDelaySlots](const SynapseGroupInternal &sg, VarAccessDim varDims) { - return getVarSize(var.access.getDims(), sg.getSrcNeuronGroup()->getNumNeurons(), - batchSize, preDelaySlots); + return getNeuronVarSize(varDims, sg.getSrcNeuronGroup()->getNumNeurons(), + batchSize, preDelaySlots); }); } @@ -1301,10 +1325,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const unsigned int postDelaySlots = (sg->getBackPropDelaySteps() == NO_DELAY) ? 1 : sg->getTrgNeuronGroup()->getNumDelaySlots(); genRunnerFusedVars(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerExtraGlobalParamFunc, *sg, mem, - [batchSize, postDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) + [batchSize, postDelaySlots](const SynapseGroupInternal &sg, VarAccessDim varDims) { - return getVarSize(var.access.getDims(), sg.getTrgNeuronGroup()->getNumNeurons(), - batchSize, postDelaySlots); + return getNeuronVarSize(varDims, sg.getTrgNeuronGroup()->getNumNeurons(), + batchSize, postDelaySlots); }); } } @@ -1404,23 +1428,17 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const auto &varInit = s.second.getWUVarInitialisers().at(wuVar.name); const bool autoInitialized = !Utils::areTokensEmpty(varInit.getCodeTokens()); const auto resolvedType = wuVar.type.resolve(modelMerged.getModel().getTypeContext()); - const unsigned int numCopies = getNumVarCopies(wuVar.access.getDims(), batchSize); - if(individualWeights) { - const size_t size = (size_t)s.second.getSrcNeuronGroup()->getNumNeurons() * (size_t)backend.getSynapticMatrixRowStride(s.second); + if(individualWeights || kernelWeights) { + const size_t size = getSynapseVarSize(wuVar.access.getDims(), + s.second.getSrcNeuronGroup()->getNumNeurons(), + backend.getSynapticMatrixRowStride(s.second), batchSize, + kernelWeights ? 
std::make_optional(s.second.getKernelSizeFlattened()) : std::nullopt); + genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, resolvedType, wuVar.name + s.second.getName(), s.second.getWUVarLocation(wuVar.name), - autoInitialized, size * numCopies, mem, synapseGroupStatePushPullFunctions); + autoInitialized, size, mem, synapseGroupStatePushPullFunctions); } - else if(kernelWeights) { - // Calculate size of kernel - const size_t size = s.second.getKernelSizeFlattened() * numCopies; - - // Generate variable - genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - runnerPushFunc, runnerPullFunc, resolvedType, wuVar.name + s.second.getName(), s.second.getWUVarLocation(wuVar.name), - autoInitialized, size, mem, synapseGroupStatePushPullFunctions); - } - + // Loop through EGPs required to initialize WUM for(const auto &e : varInit.getSnippet()->getExtraGlobalParams()) { genExtraGlobalParam(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, @@ -1446,10 +1464,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, genRunnerFusedVarPushPull(modelMerged, backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, - [batchSize](const SynapseGroupInternal &sg, const Models::Base::Var &var) + [batchSize](const SynapseGroupInternal &sg, VarAccessDim varDims) { - return getVarSize(var.access.getDims(), sg.getTrgNeuronGroup()->getNumNeurons(), - batchSize); + return getNeuronVarSize(varDims, sg.getTrgNeuronGroup()->getNumNeurons(), + batchSize); }); } @@ -1459,10 +1477,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const unsigned int preDelaySlots = (s.second.getDelaySteps() == NO_DELAY) ? 1 : s.second.getSrcNeuronGroup()->getNumDelaySlots(); genRunnerFusedVarPushPull(modelMerged, backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, - [batchSize, preDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) + [batchSize, preDelaySlots](const SynapseGroupInternal &sg, VarAccessDim varDims) { - return getVarSize(var.access.getDims(), sg.getSrcNeuronGroup()->getNumNeurons(), - batchSize, preDelaySlots); + return getNeuronVarSize(varDims, sg.getSrcNeuronGroup()->getNumNeurons(), + batchSize, preDelaySlots); }); } @@ -1473,10 +1491,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const unsigned int postDelaySlots = (s.second.getBackPropDelaySteps() == NO_DELAY) ? 
1 : s.second.getTrgNeuronGroup()->getNumDelaySlots(); genRunnerFusedVarPushPull(modelMerged, backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, - [batchSize, postDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) + [batchSize, postDelaySlots](const SynapseGroupInternal &sg, VarAccessDim varDims) { - return getVarSize(var.access.getDims(), sg.getTrgNeuronGroup()->getNumNeurons(), - batchSize, postDelaySlots); + return getNeuronVarSize(varDims, sg.getTrgNeuronGroup()->getNumNeurons(), + batchSize, postDelaySlots); }); } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index c358b00312..9e4e231f00 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -87,7 +87,7 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e }); // If variable has NEURON axis - const VarAccessDim varDims = var.access.template getDims(); + const VarAccessDim varDims = adaptor.getVarDims(var); if (varDims & VarAccessDim::NEURON) { backend.genVariableInit( varEnv, count, "id", @@ -171,7 +171,7 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, // Generate target-specific code to initialise variable genSynapseVariableRowInitFn(varEnv, - [&group, &resolvedType, &stride, &var, &varInit, batchSize] + [&adaptor, &group, &resolvedType, &stride, &var, &varInit, batchSize] (EnvironmentExternalBase &env) { // Generate initial value into temporary variable @@ -185,7 +185,7 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, // Fill value across all batches genVariableFill(varInitEnv, "_value", "$(value)", "id_syn", stride, - var.access.template getDims(), batchSize); + adaptor.getVarDims(var), batchSize); }); } } From 42f0552c40a70c7ecb586308984a613d571f9d67 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 13:05:19 +0100 Subject: [PATCH 28/60] fixed some typos in custom update group merged --- src/genn/genn/code_generator/customUpdateGroupMerged.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 2ff12b9873..1d5041181b 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -62,7 +62,7 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", [this, batchSize, &cuEnv](const std::string&, VarAccess d) { - return getVarIndex(batchSize, clearDim(getArchetype().getDims(), d.getDims()), "$(id)"); + return getVarIndex(batchSize, clearDim(getArchetype().getDims(), d.getDims()), "$(id)"); }); // Create an environment which caches variable references in local variables if they are accessed @@ -189,7 +189,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", [this, batchSize, &cuEnv](const std::string&, VarAccess d) { - return getVarIndex(batchSize, clearDim(getArchetype().getDims(), d.getDims()), "$(id_syn)"); + return getVarIndex(batchSize, clearDim(getArchetype().getDims(), d.getDims()), "$(id_syn)"); }); // Create an environment which caches variable references in local variables if they are accessed 
@@ -197,7 +197,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", [this, batchSize, &varEnv](const std::string&, const Models::WUVarReference &v) { - return getVarRefIndex(batchSize, v.getVar().access.getDims(), "$(id_syn)"); + return getVarRefIndex(batchSize, v.getDims(), "$(id_syn)"); }); Transpiler::ErrorHandler errorHandler("Custom update '" + getArchetype().getName() + "' update code"); From e152bcfd48d626de2924fea7510e1d51fa4094f6 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 13:05:42 +0100 Subject: [PATCH 29/60] extend VarAccess.get_custom_update_dims to do clearDim logic --- pygenn/src/genn.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc index c385c9d6a8..806f3f9892 100644 --- a/pygenn/src/genn.cc +++ b/pygenn/src/genn.cc @@ -705,7 +705,10 @@ PYBIND11_MODULE(genn, m) .def("get_synapse_dims", [](const VarAccess &v) { return v.getDims(); }) .def("get_custom_update_dims", - [](const VarAccess &v) { return v.getDims(); }); + [](const VarAccess &v, VarAccessDim cuDims) + { + return clearDim(cuDims, v.getDims()); + }); //------------------------------------------------------------------------ // genn.Var From 68d6bc38e88d489355c34fc7db78171bef2c2994 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 13:06:39 +0100 Subject: [PATCH 30/60] corrected getSynapseVarSize logic and changed to take SynapseGroup --- .../genn/code_generator/generateRunner.cc | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 4e9aaf3409..4100aae14d 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -43,25 +43,28 @@ size_t getNeuronVarSize(VarAccessDim varDims, size_t numElements, size_t batchSi return getNumVarCopies(varDims, batchSize, batched) * getNumNeuronVarElements(varDims, numElements) * delaySlots; } //-------------------------------------------------------------------------- -size_t getSynapseVarSize(VarAccessDim varDims, size_t numPre, size_t rowStride, size_t batchSize, - std::optional kernelSize = std::nullopt, bool batched = true) +size_t getSynapseVarSize(VarAccessDim varDims, const BackendBase &backend, const SynapseGroupInternal &sg, + size_t batchSize, bool batched = true) { - const size_t numCopies = getNumVarCopies(varDims, batchSize, batched); const bool pre = (varDims & VarAccessDim::PRE_NEURON); const bool post = (varDims & VarAccessDim::POST_NEURON); + const unsigned int numPre = sg.getSrcNeuronGroup()->getNumNeurons(); + const unsigned int numPost = sg.getTrgNeuronGroup()->getNumNeurons(); + const unsigned int rowStride = backend.getSynapticMatrixRowStride(sg); + const size_t numCopies = getNumVarCopies(varDims, batchSize, batched); if(pre && post) { - if(kernelSize) { - return kernelSize.value() * numCopies; + if(sg.getMatrixType() & SynapseMatrixWeight::KERNEL) { + return sg.getKernelSizeFlattened() * numCopies; } else { return numPre * rowStride * numCopies; } } - else if(varDims & VarAccessDim::PRE_NEURON) { + else if(pre) { return numPre * numCopies; } - else if(varDims & VarAccessDim::POST_NEURON) { - return rowStride * numCopies; + else if(post) { + return numPost * numCopies; } else { return numCopies; @@ -1212,11 +1215,8 @@ MemAlloc 
GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerPushFunc, runnerPullFunc, model.getCustomWUUpdates(), mem, statePushPullFunctions, [batchSize, &backend](const CustomUpdateWUInternal &c, VarAccessDim varDims) { - const SynapseGroupInternal *sg = c.getSynapseGroup(); - return getSynapseVarSize(varDims, sg->getSrcNeuronGroup()->getNumNeurons(), - backend.getSynapticMatrixRowStride(*sg), batchSize, - (sg->getMatrixType() & SynapseMatrixWeight::KERNEL) ? std::make_optional(sg->getKernelSizeFlattened()) : std::nullopt, - c.getDims() & VarAccessDim::BATCH); + return getSynapseVarSize(varDims, backend, *c.getSynapseGroup(), + batchSize, c.getDims() & VarAccessDim::BATCH); }); allVarStreams << std::endl; @@ -1430,9 +1430,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const auto resolvedType = wuVar.type.resolve(modelMerged.getModel().getTypeContext()); if(individualWeights || kernelWeights) { const size_t size = getSynapseVarSize(wuVar.access.getDims(), - s.second.getSrcNeuronGroup()->getNumNeurons(), - backend.getSynapticMatrixRowStride(s.second), batchSize, - kernelWeights ? std::make_optional(s.second.getKernelSizeFlattened()) : std::nullopt); + backend, s.second, batchSize); genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, resolvedType, wuVar.name + s.second.getName(), s.second.getWUVarLocation(wuVar.name), From 3b26adf77d92b7aba608550443f2d8e7fb28f229 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 13:09:54 +0100 Subject: [PATCH 31/60] updated variable shape-getting logic in PyGeNN --- pygenn/genn_groups.py | 181 ++++++++++++++++++++++++++++-------------- 1 file changed, 121 insertions(+), 60 deletions(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index 5b6447d19c..e6c2f3ca55 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -15,9 +15,46 @@ from . 
import neuron_models, types from .genn import (CustomUpdateWU, SynapseMatrixConnectivity, - SynapseMatrixWeight, VarAccessDim, VarLocation) + SynapseMatrixWeight, VarAccess, VarAccessDim, VarLocation) from .model_preprocessor import prepare_model, ExtraGlobalParameter, Variable +def _get_num_var_copies(var_dims, batch_size): + if (var_dims & VarAccessDim.BATCH): + return () if batch_size == 1 else (batch_size,) + else: + return () + +def _get_num_neuron_var_elements(var_dims, num_elements): + if (var_dims & VarAccessDim.NEURON): + return (num_elements,) + else: + return (1,) + +def _get_neuron_var_shape(var_dims, num_elements, batch_size, + num_delay_slots=1): + num_delay_slots = () if num_delay_slots == 1 else (num_delay_slots,) + return (_get_num_var_copies(var_dims, batch_size) + + num_delay_slots + + _get_num_neuron_var_elements(var_dims, num_elements)) + +def _get_synapse_var_shape(var_dims, sg, batch_size): + num_copies = _get_num_var_copies(var_dims, batch_size) + pre = (var_dims & VarAccessDim.PRE_NEURON) + post = (var_dims & VarAccessDim.POST_NEURON) + num_pre = sg.src.size + num_post = sg.trg.size + if pre and post: + if sg.matrix_type & SynapseMatrixWeight.KERNEL: + return num_copies + (np.product(sg.kernel_size),) + else: + # **YUCK** this isn't correct - only backend knows correct stride + return num_copies + (num_pre * sg.max_connections,) + elif pre: + return num_copies + (num_pre,) + elif post: + return num_copies + (num_post,) + else: + return num_copies + (1,) class GroupMixin(object): @@ -83,7 +120,7 @@ def push_extra_global_param_to_device(self, egp_name): egp_name -- string with the name of the variable """ self._push_extra_global_param_to_device(egp_name) - + def _assign_ext_ptr_array(self, var_name, var_size, var_type): """Assign a variable to an external numpy array @@ -168,12 +205,8 @@ def _pull_extra_global_param_from_device(self, egp_name, egp_dict=None): self._model._slm.pull_extra_global_param_from_device(self.name, egp_name, len(egp.values)) - def _load_vars(self, vars, size=None, var_dict=None, - get_location_fn=None, batched=True): - # If no size is specified, use standard size - if size is None: - size = self.size - + def _load_vars(self, vars, get_shape_fn, var_dict=None, + get_location_fn=None): # If no variable dictionary is specified, use standard one if var_dict is None: var_dict = self.vars @@ -190,23 +223,15 @@ def _load_vars(self, vars, size=None, var_dict=None, # If variable is located on host var_loc = get_location_fn(v.name) if var_loc & VarLocation.HOST: - # Determine how many copies of this variable are present - var_batched = (batched and not v.access & VarAccessDuplication.SHARED) - num_copies = self._model.batch_size if var_batched else 1 - - # Determine size of this variable - var_size = (1 if v.access & VarAccessDuplication.SHARED_NEURON - else size) - + # Determine shape of this variable + var_shape = get_shape_fn(v) + # Get view resolved_type = var_data.type.resolve(self._model.type_context) var_data._view = self._assign_ext_ptr_array( - v.name, var_size * num_copies, resolved_type) + v.name, np.prod(var_shape), resolved_type) - # If there is more than one copy, reshape view to 2D - if num_copies > 1: - var_data._view = np.reshape(var_data._view, - (num_copies, -1)) + var_data._view = np.reshape(var_data._view, var_shape) # If manual initialisation is required, copy over variables if var_data.init_required: @@ -327,7 +352,12 @@ def load(self, num_recording_timesteps): "spkQuePtr", types.Uint32) # Load neuron state variables - 
self._load_vars(self.neuron_model.get_vars()) + # **TODO** delay slots + self._load_vars( + self.neuron_model.get_vars(), + lambda v: _get_neuron_var_shape(v.access.get_neuron_dims(), + self.size, + self._model.batch_size)) # Load neuron extra global params self._load_egp() @@ -623,19 +653,17 @@ def load(self): # If variable is located on host var_loc = self.get_wu_var_location(v.name) if var_loc & VarLocation.HOST: - # Determine how many copies of this variable are present - num_copies = (1 if (v.access & VarAccessDuplication.SHARED) != 0 - else self._model.batch_size) + # Determine shape of this variable + var_shape = _get_synapse_var_shape( + v.access.get_synapse_dims(), + self, self._model.batch_size) + # Get view resolved_type = var_data.type.resolve(self._model.type_context) var_data._view = self._assign_ext_ptr_array( - v.name, self.weight_update_var_size * num_copies, - resolved_type) + v.name, np.prod(var_shape), resolved_type) - # If there is more than one copy, reshape view to 2D - if num_copies > 1: - var_data._view = np.reshape(var_data._view, - (num_copies, -1)) + var_data._view = np.reshape(var_data._view, var_shape) # Initialise variable if necessary self._init_wum_var(var_data, num_copies) @@ -648,21 +676,35 @@ def load(self): # If population's presynaptic weight update hasn't been # fused, load weight update model presynaptic variables + # **TODO** delay if not self._wu_pre_model_fused: - self._load_vars(self.wu_model.get_pre_vars(), self.src.size, - self.pre_vars, self.get_wu_pre_var_location) + self._load_vars( + self.wu_model.get_pre_vars(), + lambda v: _get_neuron_var_shape(v.access.get_neuron_dims(), + self.src.size, + self._model.batch_size), + self.pre_vars, self.get_wu_pre_var_location) # If population's postsynaptic weight update hasn't been # fused, load weight update model postsynaptic variables + # **TODO** delay if not self._wu_post_model_fused: - self._load_vars(self.wu_model.get_post_vars(), self.trg.size, - self.post_vars, self.get_wu_post_var_location) + self._load_vars( + self.wu_model.get_post_vars(), + lambda v: _get_neuron_var_shape(v.access.get_neuron_dims(), + self.trg.size, + self._model.batch_size), + self.post_vars, self.get_wu_post_var_location) # If this synapse group's postsynaptic model hasn't been fused if not self._ps_model_fused: # Load postsynaptic update model variables - self._load_vars(self.ps_model.get_vars(), self.trg.size, - self.psm_vars, self.get_ps_var_location) + self._load_vars( + self.ps_model.get_vars(), + lambda v, b: _get_neuron_var_shape(v.access.get_neuron_dims(), + self.trg.size, + self._model.batch_size), + self.psm_vars, self.get_ps_var_location) # If it's inSyn is accessible on the host if self.in_syn_location & VarLocation.HOST: @@ -796,7 +838,11 @@ def size(self): def load(self): # Load current source variables - self._load_vars(self.current_source_model.get_vars()) + self._load_vars( + self.current_source_model.get_vars(), + lambda v: _get_neuron_var_shape(v.access.get_neuron_dims(), + self.size, + self._model.batch_size)) # Load current source extra global parameters self._load_egp() @@ -824,8 +870,13 @@ def _init_group(self, model, var_space): self.custom_update_model, self, var_space) def load(self): - self._load_vars(self.custom_update_model.get_vars(), - size=self.size, batched=self._is_batched) + batch_size = (self._model.batch_size + if self._dims & VarAccessDim.BATCH + else 1) + self._load_vars( + self.custom_update_model.get_vars(), + lambda v: 
_get_neuron_var_shape(v.access.get_custom_update_dims(self._dims), + self.size, batch_size)) self._load_egp() def load_init_egps(self): @@ -856,6 +907,9 @@ def load(self): & SynapseMatrixConnectivity.PROCEDURAL) # Loop through state variables + batch_size = (self._model.batch_size + if self._dims & VarAccessDim.BATCH + else 1) for v in self.custom_update_model.get_vars(): # Get corresponding data from dictionary var_data = self.vars[v.name] @@ -863,20 +917,17 @@ def load(self): # If variable is located on host var_loc = self.get_var_location(v.name) if var_loc & VarLocation.HOST: - # Determine how many copies of this variable are present - var_batched = (self._is_batched and not v.access & VarAccessDuplication.SHARED) - num_copies = self._model.batch_size if var_batched else 1 - + # Determine shape of this variable + var_shape = _get_synapse_var_shape( + v.access.get_custom_update_dims(self._dims), + self, batch_size) + # Get view - size = self.synapse_group.weight_update_var_size * num_copies resolved_type = var_data.type.resolve(self._model.type_context) var_data._view = self._assign_ext_ptr_array( - v.name, size, resolved_type) + v.name, np.prod(var_shape), resolved_type) - # If there is more than one copy, reshape view to 2D - if num_copies > 1: - var_data._view = np.reshape(var_data._view, - (num_copies, -1)) + var_data._view = np.reshape(var_data._view, var_shape) # Initialise variable if necessary self.synapse_group._init_wum_var(var_data, num_copies) @@ -939,12 +990,18 @@ def load(self): # If variable is located on host var_loc = self.get_var_location(v.name) if var_loc & VarLocation.HOST: - # Get view - size = self.synapse_group.weight_update_var_size + # Determine shape of this variable + var_shape = _get_synapse_var_shape( + v.access.get_synapse_dims(), + self, 1) + resolved_type = var_data.type.resolve(self._model.type_context) var_data._view = self._assign_ext_ptr_array( - v.name, size, resolved_type) - + v.name, np.prod(var_shape), resolved_type) + + # **TODO** do this in assign_ext_ptr_array + var_data._view = np.reshape(var_data._view, var_shape) + # Initialise variable if necessary self.synapse_group._init_wum_var(var_data, 1) @@ -952,12 +1009,16 @@ def load(self): self._load_egp(var_data.extra_global_params, v.name) # Load pre and postsynaptic variables - self._load_vars(self.model.get_pre_vars(), - self.synapse_group.src.size, - self.pre_vars, self.get_pre_var_location, False) - self._load_vars(self.model.get_post_vars(), - self.synapse_group.trg.size, - self.post_vars, self.get_post_var_location, False) + self._load_vars( + self.model.get_pre_vars(), + lambda v: _get_neuron_var_shape(v.access.get_neuron_dims(), + self.synapse_group.src.size, 1), + self.pre_vars, self.get_pre_var_location) + self._load_vars( + self.model.get_post_vars(), + lambda v: _get_neuron_var_shape(v.access.get_neuron_dims(), + self.synapse_group.trg.size, 1), + self.post_vars, self.get_post_var_location) # Load custom update extra global parameters self._load_egp() From 23f14160bf8b87ca126f71ae655d5c258bf403b2 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 15:00:32 +0100 Subject: [PATCH 32/60] fix bug in environment --- include/genn/genn/code_generator/environment.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 4e057c8fc6..f7fc6844f6 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ 
-723,7 +723,12 @@ class VarCachePolicy //------------------------------------------------------------------------ bool shouldAlwaysCopy(G&, const Models::Base::Var &var) const { - return m_ShouldAlwaysCopy(var.name, var.access); + if(m_ShouldAlwaysCopy) { + return m_ShouldAlwaysCopy(var.name, var.access); + } + else { + return false; + } } std::string getReadIndex(G&, const Models::Base::Var &var) const From 6afbe5ebbb25dcc4fa8365a71a4b079ddf0b36d6 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 15:32:28 +0100 Subject: [PATCH 33/60] incorrect flags --- .../code_generator/presynapticUpdateStrategySIMT.cc | 10 +++++----- .../genn/code_generator/synapseUpdateGroupMerged.cc | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index 73de775c6b..6598e09847 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -96,7 +96,7 @@ void PreSpan::genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerg { CodeStream::Scope b(env.getStream()); - env.printLine("const unsigned int preInd = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::PRE_NEURON, "spike") + "];"); + env.printLine("const unsigned int preInd = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "spike") + "];"); const auto indexType = backend.getSynapseIndexType(sg); const auto indexTypeName = indexType.getName(); @@ -247,7 +247,7 @@ void PostSpan::genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMer { CodeStream::Scope b(env.getStream()); const std::string index = "(r * " + std::to_string(backend.getKernelBlockSize(KernelPresynapticUpdate)) + ") + " + backend.getThreadID(); - env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::PRE_NEURON, index) + "];"); + env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, index) + "];"); env.printLine("$(_sh_spk" + eventSuffix + ")[" + backend.getThreadID() + "] = spk;"); if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { env.printLine("$(_sh_row_length)[" + backend.getThreadID() + "] = $(_row_length)[spk];"); @@ -459,7 +459,7 @@ void PreSpanProcedural::genUpdate(EnvironmentExternalBase &env, PresynapticUpdat // Create environment and add presynaptic index EnvironmentGroupMergedField synEnv(groupEnv, sg); synEnv.add(Type::Uint32.addConst(), "id_pre", "preInd", - {synEnv.addInitialiser("const unsigned int preInd = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::PRE_NEURON, "$(_spike)") + "];")}); + {synEnv.addInitialiser("const unsigned int preInd = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(_spike)") + "];")}); // **YUCK** add a hidden copy of num_post so we can overwrite deeper in here without losing access to original synEnv.add(Type::Uint32.addConst(), "_num_post", "$(num_post)"); @@ -639,7 +639,7 @@ void PostSpanBitmask::genUpdate(EnvironmentExternalBase &env, PresynapticUpdateG { CodeStream::Scope b(env.getStream()); const std::string index = "(r * " + 
std::to_string(backend.getKernelBlockSize(KernelPresynapticUpdate)) + ") + " + backend.getThreadID(); - env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::PRE_NEURON, index) + "];"); + env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, index) + "];"); env.printLine("$(_sh_spk" + eventSuffix + ")[" + backend.getThreadID() + "] = spk;"); } backend.genSharedMemBarrier(env.getStream()); @@ -873,7 +873,7 @@ void PostSpanToeplitz::genUpdate(EnvironmentExternalBase &env, PresynapticUpdate { CodeStream::Scope b(env.getStream()); const std::string index = "(r * " + std::to_string(backend.getKernelBlockSize(KernelPresynapticUpdate)) + ") + " + backend.getThreadID(); - env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::PRE_NEURON, index) + "];"); + env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, index) + "];"); env.printLine("$(_sh_spk" + eventSuffix + ")[" + backend.getThreadID() + "] = spk;"); } backend.genSharedMemBarrier(env.getStream()); diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 2974acbce4..0772e46ddf 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -53,8 +53,8 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa const std::string timeStr = sg.getTimeType().getName(); const std::string axonalDelayMs = Type::writeNumeric(dt * (double)(sg.getArchetype().getDelaySteps() + 1u), sg.getTimeType()); const bool preDelay = sg.getArchetype().getSrcNeuronGroup()->isDelayRequired(); - const std::string preSTIndex = sg.getPreVarIndex(preDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::PRE_NEURON, "$(id_pre)"); - const std::string prevPreSTIndex = sg.getPrePrevSpikeTimeIndex(preDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::PRE_NEURON, "$(id_pre)"); + const std::string preSTIndex = sg.getPreVarIndex(preDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id_pre)"); + const std::string prevPreSTIndex = sg.getPrePrevSpikeTimeIndex(preDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id_pre)"); synEnv.add(sg.getTimeType().addConst(), "st_pre", "stPre", {synEnv.addInitialiser("const " + timeStr + " stPre = " + axonalDelayMs + " + $(_src_st)[" + preSTIndex + "];")}); synEnv.add(sg.getTimeType().addConst(), "prev_st_pre", "prevSTPre", @@ -67,8 +67,8 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa // Calculate backprop delay to add to (somatic) spike times and substitute in postsynaptic spike times const std::string backPropDelayMs = Type::writeNumeric(dt * (double)(sg.getArchetype().getBackPropDelaySteps() + 1u), sg.getTimeType()); const bool postDelay = sg.getArchetype().getTrgNeuronGroup()->isDelayRequired(); - const std::string postSTIndex = sg.getPostVarIndex(postDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::POST_NEURON, "$(id_post)"); - const std::string prevPostSTIndex = sg.getPostPrevSpikeTimeIndex(postDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::POST_NEURON, "$(id_post)"); + const std::string postSTIndex = sg.getPostVarIndex(postDelay, 
batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id_post)"); + const std::string prevPostSTIndex = sg.getPostPrevSpikeTimeIndex(postDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id_post)"); synEnv.add(sg.getTimeType().addConst(), "st_post", "stPost", {synEnv.addInitialiser("const " + timeStr + " stPost = " + backPropDelayMs + " + $(_trg_st)[" + postSTIndex + "];")}); synEnv.add(sg.getTimeType().addConst(), "prev_st_post", "prevSTPost", From 5dffc5cf287bd9a296ddf7da56cfc0cf6e749a75 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 15:58:34 +0100 Subject: [PATCH 34/60] PyBind11 wrapper for ``Var`` cannot implicitly convert ``NeuronVarAccess`` etc to ``VarAccess`` - make explicit --- pygenn/src/genn.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc index 806f3f9892..ab588bc3db 100644 --- a/pygenn/src/genn.cc +++ b/pygenn/src/genn.cc @@ -696,10 +696,6 @@ PYBIND11_MODULE(genn, m) // genn.VarAccess //------------------------------------------------------------------------ pybind11::class_(m, "VarAccess") - .def(pybind11::init()) - .def(pybind11::init()) - .def(pybind11::init()) - .def("get_neuron_dims", [](const VarAccess &v) { return v.getDims(); }) .def("get_synapse_dims", @@ -714,9 +710,13 @@ PYBIND11_MODULE(genn, m) // genn.Var //------------------------------------------------------------------------ pybind11::class_(m, "Var") - .def(pybind11::init()) + .def(pybind11::init()) + .def(pybind11::init()) + .def(pybind11::init()) .def(pybind11::init()) - .def(pybind11::init()) + .def(pybind11::init()) + .def(pybind11::init()) + .def(pybind11::init()) .def(pybind11::init()) .def_readonly("name", &Models::Base::Var::name) .def_readonly("type", &Models::Base::Var::type) From 426d100f422d523fe9faf40572e7e23285decce7 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 16:08:50 +0100 Subject: [PATCH 35/60] removed more calls to getDims on var rather than var reference and added similar wrapper for transpose --- include/genn/genn/models.h | 8 ++++-- src/genn/genn/customUpdate.cc | 4 +-- src/genn/genn/models.cc | 52 +++++++++++++++++------------------ 3 files changed, 34 insertions(+), 30 deletions(-) diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index bae2c03d35..3c341a7527 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -271,13 +271,16 @@ class GENN_EXPORT WUVarReference : public VarReferenceBase std::string getTargetName() const; //! Get dimensions of variable being referenced - VarAccessDim getDims() const; + VarAccessDim getDims() const{ return getVarDims(getVar()); } SynapseGroup *getSynapseGroup() const; SynapseGroup *getTransposeSynapseGroup() const; std::string getTransposeTargetName() const; + //! Get dimensions of transpose variable being referenced + VarAccessDim getTransposeDims() const{ return getVarDims(getTransposeVar()); } + //! If this reference points to another custom update, return pointer to it /*! 
This is used to detect circular dependencies */ CustomUpdateWU *getReferencedCustomUpdate() const; @@ -321,13 +324,14 @@ class GENN_EXPORT WUVarReference : public VarReferenceBase //------------------------------------------------------------------------ SynapseGroupInternal *getSynapseGroupInternal() const; SynapseGroupInternal *getTransposeSynapseGroupInternal() const; + VarAccessDim getVarDims(const Models::Base::Var &var) const; WUVarReference(size_t varIndex, const Models::Base::VarVec &varVec, const DetailType &detail); WUVarReference(size_t varIndex, const Models::Base::VarVec &varVec, size_t transposeVarIndex, const Models::Base::VarVec &transposeVarVec, const DetailType &detail); - + //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index 64bb15b6af..41ee1456b7 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -298,8 +298,8 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getHashDigest() const // Update hash with whether variable references require transpose Utils::updateHash((v.second.getTransposeSynapseGroup() == nullptr), hash); - // Update hash with access mode of target variable dimensions as this effects indexing code - Utils::updateHash(v.second.getVar().access.getDims(), hash); + // Update hash with dimensionality of target variable dimensions as this effects indexing code + Utils::updateHash(v.second.getDims(), hash); } return hash.get_digest(); diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index f079fabb49..90bd392553 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -188,30 +188,6 @@ std::string WUVarReference::getTargetName() const m_Detail); } //---------------------------------------------------------------------------- -VarAccessDim WUVarReference::getDims() const -{ - const auto &varAccess = getVar().access; - return std::visit( - Utils::Overload{ - // If reference is to a custom update variable, - // remove dimensions from those of update - [&varAccess](const CURef &ref) - { - return clearDim(ref.group->getDims(), - varAccess.getDims()); - }, - // Otherwise, if reference is to the synaptic variables of a custom connectivity update, - // remove BATCH dimension as these are never batched - [&varAccess](const CCURef&) - { - return clearDim(varAccess.getDims(), - VarAccessDim::BATCH); - }, - // Otherwise, use dimensionality directly - [&varAccess](const WURef&){ return varAccess.getDims(); }}, - m_Detail); -} -//---------------------------------------------------------------------------- SynapseGroup *WUVarReference::getSynapseGroup() const { return getSynapseGroupInternal(); @@ -305,6 +281,30 @@ SynapseGroupInternal *WUVarReference::getTransposeSynapseGroupInternal() const m_Detail); } //------------------------------------------------------------------------ +VarAccessDim WUVarReference::getVarDims(const Models::Base::Var &var) const +{ + const auto &varAccess = var.access; + return std::visit( + Utils::Overload{ + // If reference is to a custom update variable, + // remove dimensions from those of update + [&varAccess](const CURef &ref) + { + return clearDim(ref.group->getDims(), + varAccess.getDims()); + }, + // Otherwise, if reference is to the synaptic variables of a custom connectivity update, + // remove BATCH dimension as these are never batched + [&varAccess](const CCURef&) + { + return 
clearDim(varAccess.getDims(), + VarAccessDim::BATCH); + }, + // Otherwise, use dimensionality directly + [&varAccess](const WURef&){ return varAccess.getDims(); }}, + m_Detail); +} +//------------------------------------------------------------------------ WUVarReference::WUVarReference(size_t varIndex, const Models::Base::VarVec &varVec, const DetailType &detail) : VarReferenceBase(varIndex, varVec), m_TransposeVarIndex(std::nullopt), @@ -354,8 +354,8 @@ WUVarReference::WUVarReference(size_t varIndex, const Models::Base::VarVec &varV } // Check duplicatedness of variables - if((getVar().access.getDims() & VarAccessDim::BATCH) - != (getTransposeVar().access.getDims() & VarAccessDim::BATCH)) + if((getDims() & VarAccessDim::BATCH) + != (getTransposeDims() & VarAccessDim::BATCH)) { throw std::runtime_error("Transpose updates can only be performed on similarly batched variables"); } From 79cac2f33af34123855f592d32a47a7fc08014ed Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 16:30:55 +0100 Subject: [PATCH 36/60] don't generate code for empty custom updates --- src/genn/genn/code_generator/modelSpecMerged.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index 660e92b718..4a4dfc3cb4 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -34,11 +34,11 @@ using namespace GeNN::CodeGenerator; &SynapseGroupInternal::getWUHashDigest); createMergedGroups(getModel().getCustomUpdates(), m_MergedCustomUpdateGroups, - [](const CustomUpdateInternal&) { return true; }, + [](const CustomUpdateInternal &cg) { return !Utils::areTokensEmpty(cg.getUpdateCodeTokens()); }, &CustomUpdateInternal::getHashDigest); createMergedGroups(getModel().getCustomWUUpdates(), m_MergedCustomUpdateWUGroups, - [](const CustomUpdateWUInternal &cg) { return !cg.isTransposeOperation(); }, + [](const CustomUpdateWUInternal &cg) { return !Utils::areTokensEmpty(cg.getUpdateCodeTokens()) && !cg.isTransposeOperation(); }, &CustomUpdateWUInternal::getHashDigest); createMergedGroups(getModel().getCustomWUUpdates(), m_MergedCustomUpdateTransposeWUGroups, From 1a13a3ad6b4d348a6326efbed016b407e374b030 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 16:31:12 +0100 Subject: [PATCH 37/60] infer batchedness from shape in _init_wum_var --- pygenn/genn_groups.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index e6c2f3ca55..6cc5962af9 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -666,7 +666,7 @@ def load(self): var_data._view = np.reshape(var_data._view, var_shape) # Initialise variable if necessary - self._init_wum_var(var_data, num_copies) + self._init_wum_var(var_data) else: assert not var_data.init_required var_data._view = None @@ -701,9 +701,9 @@ def load(self): # Load postsynaptic update model variables self._load_vars( self.ps_model.get_vars(), - lambda v, b: _get_neuron_var_shape(v.access.get_neuron_dims(), - self.trg.size, - self._model.batch_size), + lambda v: _get_neuron_var_shape(v.access.get_neuron_dims(), + self.trg.size, + self._model.batch_size), self.psm_vars, self.get_ps_var_location) # If it's inSyn is accessible on the host @@ -781,7 +781,7 @@ def _get_view_values(self, var_view): else: raise Exception("Matrix format not supported") - def _init_wum_var(self, var_data, num_copies): + def 
_init_wum_var(self, var_data): # If initialisation is required if var_data.init_required: # If connectivity is dense, @@ -792,7 +792,7 @@ def _init_wum_var(self, var_data, num_copies): var_data._view[:] = var_data.values elif (self.matrix_type & SynapseMatrixConnectivity.SPARSE): # Sort variable to match GeNN order - if num_copies == 1: + if len(var_data.shape) == 1: sorted_var = var_data.values[self.synapse_order] else: sorted_var = var_data.values[:,self.synapse_order] @@ -806,7 +806,7 @@ def _init_wum_var(self, var_data, num_copies): syn = 0 for i, r in zip(row_start_idx, self.row_lengths): # Copy row from non-padded indices into correct location - if num_copies == 1: + if len(var_data.shape) == 1: var_data._view[i:i + r] = sorted_var[syn:syn + r] else: var_data._view[:,i:i + r] = sorted_var[:,syn:syn + r] @@ -920,7 +920,7 @@ def load(self): # Determine shape of this variable var_shape = _get_synapse_var_shape( v.access.get_custom_update_dims(self._dims), - self, batch_size) + self.synapse_group, batch_size) # Get view resolved_type = var_data.type.resolve(self._model.type_context) @@ -930,7 +930,7 @@ def load(self): var_data._view = np.reshape(var_data._view, var_shape) # Initialise variable if necessary - self.synapse_group._init_wum_var(var_data, num_copies) + self.synapse_group._init_wum_var(var_data) # Load any var initialisation egps associated with this variable self._load_egp(var_data.extra_global_params, v.name) @@ -1003,7 +1003,7 @@ def load(self): var_data._view = np.reshape(var_data._view, var_shape) # Initialise variable if necessary - self.synapse_group._init_wum_var(var_data, 1) + self.synapse_group._init_wum_var(var_data) # Load any var initialisation egps associated with this variable self._load_egp(var_data.extra_global_params, v.name) From 8efb89eff23aa201312d155559008dd8b184bea9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 16:39:12 +0100 Subject: [PATCH 38/60] fixed dimension flags being used backwards --- src/genn/genn/code_generator/backendSIMT.cc | 38 ++++++++++----------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index fb73b6309d..cb06968be0 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -210,7 +210,7 @@ size_t BackendSIMT::getPaddedNumCustomUpdateThreads(const CustomUpdateInternal & if (cg.isNeuronReduction()) { return padKernelSize(32 * numCopies, KernelCustomUpdate); } - else if (!(cg.getDims() & VarAccessDim::NEURON)) { + else if (cg.getDims() & VarAccessDim::NEURON) { return numCopies * padKernelSize(cg.getSize(), KernelCustomUpdate); } else { @@ -1042,25 +1042,8 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge } } } - // Otherwise, if this update isn't per-neuron + // Otherwise, if this update is per-neuron else if (cg.getArchetype().getDims() & VarAccessDim::NEURON) { - // Use local ID for batch and always use zero for ID - groupEnv.add(Type::Uint32.addConst(), "batch", "$(_id)"); - groupEnv.add(Type::Uint32.addConst(), "id", "0"); - - groupEnv.getStream() << "// only do this for existing neurons" << std::endl; - groupEnv.getStream() << "if(" << groupEnv["batch"] << " < " << ((cg.getArchetype().getDims() & VarAccessDim::BATCH) ? 
batchSize : 1) << ")"; - { - CodeStream::Scope b(groupEnv.getStream()); - EnvironmentGroupMergedField batchEnv(groupEnv, cg); - buildStandardEnvironment(batchEnv, batchSize); - - cg.generateCustomUpdate(*this, batchEnv, batchSize, - [](auto&, auto&){}); - } - } - // Otherwise - else { if((cg.getArchetype().getDims() & VarAccessDim::BATCH) && (batchSize > 1)) { // Split ID into intra-batch ID and batch // **TODO** fast-divide style optimisations here @@ -1089,6 +1072,23 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge [](auto&, auto&){}); } } + // Otherwise + else { + // Use local ID for batch and always use zero for ID + groupEnv.add(Type::Uint32.addConst(), "batch", "$(_id)"); + groupEnv.add(Type::Uint32.addConst(), "id", "0"); + + groupEnv.getStream() << "// only do this for existing neurons" << std::endl; + groupEnv.getStream() << "if(" << groupEnv["batch"] << " < " << ((cg.getArchetype().getDims() & VarAccessDim::BATCH) ? batchSize : 1) << ")"; + { + CodeStream::Scope b(groupEnv.getStream()); + EnvironmentGroupMergedField batchEnv(groupEnv, cg); + buildStandardEnvironment(batchEnv, batchSize); + + cg.generateCustomUpdate(*this, batchEnv, batchSize, + [](auto&, auto&){}); + } + } }); } //-------------------------------------------------------------------------- From 600df8b210609473390260dae95c6f3f9aee5964 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 16:53:00 +0100 Subject: [PATCH 39/60] correctly resolve custom update var accesses in reductions --- include/genn/backends/single_threaded_cpu/backend.h | 6 ++++-- include/genn/genn/code_generator/backendBase.h | 4 +++- src/genn/backends/single_threaded_cpu/backend.cc | 4 ++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index e7d4bb585e..93fd13577f 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -246,7 +246,7 @@ class BACKEND_EXPORT Backend : public BackendBase /*! 
Because reduction operations are unnecessary in unbatched single-threaded CPU models so there's no need to actually reduce */ void genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateWUGroupMergedBase &cg, const std::string &idxName) const; - template + template void genWriteBackReductions(EnvironmentExternalBase &env, G &cg, const std::string &idxName, R getVarRefIndexFn) const { const auto *cm = cg.getArchetype().getCustomUpdateModel(); @@ -254,7 +254,9 @@ class BACKEND_EXPORT Backend : public BackendBase // If variable is a reduction target, copy value from register straight back into global memory if(v.access & VarAccessModeAttribute::REDUCE) { const std::string idx = env.getName(idxName); - env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(1, v.access.template getDims(), idx) << "] = " << env[v.name] << ";" << std::endl; + const VarAccessDim varAccessDim = clearDim(cg.getArchetype().getDims(), + v.access.template getDims()); + env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(1, varAccessDim, idx) << "] = " << env[v.name] << ";" << std::endl; } } diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index cd1f9aa390..14edf81e8f 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -574,8 +574,10 @@ class GENN_EXPORT BackendBase if (v.access & VarAccessModeAttribute::REDUCE) { const auto resolvedType = v.type.resolve(cg.getTypeContext()); os << resolvedType.getName() << " _lr" << v.name << " = " << getReductionInitialValue(v.access, resolvedType) << ";" << std::endl; + const VarAccessDim varAccessDim = clearDim(cg.getArchetype().getDims(), + v.access.template getDims()); reductionTargets.push_back({v.name, resolvedType, v.access, - cg.getVarIndex(batchSize, v.access.template getDims(), idx)}); + cg.getVarIndex(batchSize, varAccessDim, idx)}); } } diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index d4a14e3546..826327071b 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -2013,7 +2013,7 @@ void Backend::genEmitSpike(EnvironmentExternalBase &env, NeuronUpdateGroupMerged //-------------------------------------------------------------------------- void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateGroupMerged &cg, const std::string &idxName) const { - genWriteBackReductions( + genWriteBackReductions( env, cg, idxName, [&cg](const Models::VarReference &varRef, const std::string &index) { @@ -2024,7 +2024,7 @@ void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateG //-------------------------------------------------------------------------- void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateWUGroupMergedBase &cg, const std::string &idxName) const { - genWriteBackReductions( + genWriteBackReductions( env, cg, idxName, [&cg](const Models::WUVarReference &varRef, const std::string &index) { From 62f0e0527701e6ef5576897bb9590f54badcd374 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 16:56:28 +0100 Subject: [PATCH 40/60] fixed some variable access typos in test_custom_update --- tests/features/test_custom_update.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/features/test_custom_update.py b/tests/features/test_custom_update.py index 504620987f..2e6ecba327 100644 --- 
a/tests/features/test_custom_update.py +++ b/tests/features/test_custom_update.py @@ -51,7 +51,7 @@ def test_custom_update(backend, precision, batch_size): custom_update_model = create_custom_update_model( "custom_update", - var_name_types=[("X", "scalar", CustomUpdateVarAccess.READ_ONLY_DUPLICATE)], + var_name_types=[("X", "scalar", CustomUpdateVarAccess.READ_ONLY)], var_refs=[("R", "scalar")]) set_time_custom_update_model = create_custom_update_model( @@ -272,8 +272,8 @@ def test_custom_update_transpose(backend, precision, batch_size): def test_custom_update_neuron_reduce(backend, precision, batch_size): reduction_neuron_model = create_neuron_model( "reduction_neuron", - var_name_types=[("X", "scalar", CustomUpdateVarAccess.READ_ONLY_DUPLICATE), - ("Y", "scalar", CustomUpdateVarAccess.READ_ONLY_DUPLICATE)]) + var_name_types=[("X", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE), + ("Y", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE)]) softmax_1_custom_update_model = create_custom_update_model( "softmax_1", From 9fce2152185e857c6aed579d750df8fd619a45be Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 16:56:38 +0100 Subject: [PATCH 41/60] fixed typos in _init_wum_var --- pygenn/genn_groups.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index 6cc5962af9..c17b494f07 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -789,10 +789,10 @@ def _init_wum_var(self, var_data): # **NOTE** we assume order is row-major if ((self.matrix_type & SynapseMatrixConnectivity.DENSE) or (self.matrix_type & SynapseMatrixWeight.KERNEL)): - var_data._view[:] = var_data.values + var_data.view[:] = var_data.values elif (self.matrix_type & SynapseMatrixConnectivity.SPARSE): # Sort variable to match GeNN order - if len(var_data.shape) == 1: + if len(var_data.view.shape) == 1: sorted_var = var_data.values[self.synapse_order] else: sorted_var = var_data.values[:,self.synapse_order] @@ -806,10 +806,10 @@ def _init_wum_var(self, var_data): syn = 0 for i, r in zip(row_start_idx, self.row_lengths): # Copy row from non-padded indices into correct location - if len(var_data.shape) == 1: - var_data._view[i:i + r] = sorted_var[syn:syn + r] + if len(var_data.view.shape) == 1: + var_data.view[i:i + r] = sorted_var[syn:syn + r] else: - var_data._view[:,i:i + r] = sorted_var[:,syn:syn + r] + var_data.view[:,i:i + r] = sorted_var[:,syn:syn + r] syn += r else: raise Exception("Matrix format not supported") From 995c7b839a45d5dbdce906aba14b4da30a27e60c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 17:02:15 +0100 Subject: [PATCH 42/60] typo in loading of custom connectivity update groups --- pygenn/genn_groups.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index c17b494f07..0a62d04b28 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -993,7 +993,7 @@ def load(self): # Determine shape of this variable var_shape = _get_synapse_var_shape( v.access.get_synapse_dims(), - self, 1) + self.synapse_group, 1) resolved_type = var_data.type.resolve(self._model.type_context) var_data._view = self._assign_ext_ptr_array( From 755d192a59b362c5dc633236af33970b239967b2 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 5 Sep 2023 11:40:38 +0100 Subject: [PATCH 43/60] * switched models to use vectors of correct enum type * derived Var classes with correct defaults * restored inline helper classes for extracting bits of XXXVarAccess 
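For illustration only: with these changes a neuron model declares its state variables as a vector of NeuronVar, whose access member is a NeuronVarAccess rather than the generic VarAccess, so a neuron variable can no longer be given a synapse- or custom-update-specific access value. A rough sketch of the intended usage follows; the MyNeuron class and its variable names are hypothetical and not part of this patch, and it overrides getVars() directly rather than going through the usual model-definition macros:

    // Minimal sketch, assuming the NeuronVar struct and the
    // std::vector<NeuronVar> getVars() signature added by this patch
    class MyNeuron : public NeuronModels::Base
    {
    public:
        virtual std::vector<Models::Base::NeuronVar> getVars() const override
        {
            // Omitting the access value defaults it to
            // NeuronVarAccess::READ_WRITE via the NeuronVar constructor
            return {{"V", "scalar"}, {"RefracTime", "scalar"}};
        }
    };

The equivalent SynapseVar and CustomUpdateVar structs default their access to SynapseVarAccess::READ_WRITE and CustomUpdateVarAccess::READ_WRITE respectively.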
--- include/genn/genn/currentSourceModels.h | 9 ++ .../genn/customConnectivityUpdateModels.h | 13 ++- include/genn/genn/customUpdateModels.h | 9 ++ include/genn/genn/models.h | 64 ++++++----- include/genn/genn/neuronModels.h | 9 ++ include/genn/genn/postsynapticModels.h | 9 ++ include/genn/genn/varAccess.h | 100 +++++------------- include/genn/genn/weightUpdateModels.h | 13 ++- 8 files changed, 122 insertions(+), 104 deletions(-) diff --git a/include/genn/genn/currentSourceModels.h b/include/genn/genn/currentSourceModels.h index 1a043b0d6d..863c390db4 100644 --- a/include/genn/genn/currentSourceModels.h +++ b/include/genn/genn/currentSourceModels.h @@ -32,12 +32,21 @@ class GENN_EXPORT Base : public Models::Base //! Gets the code that defines current injected each timestep virtual std::string getInjectionCode() const{ return ""; } + //! Gets model variables + virtual std::vector getVars() const{ return {}; } + //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- //! Update hash from model boost::uuids::detail::sha1::digest_type getHashDigest() const; + //! Find the index of a named variable + size_t getVarIndex(const std::string &varName) const + { + return getNamedVecIndex(varName, getVars()); + } + //! Validate names of parameters etc void validate(const std::unordered_map ¶mValues, const std::unordered_map &varValues, diff --git a/include/genn/genn/customConnectivityUpdateModels.h b/include/genn/genn/customConnectivityUpdateModels.h index 9d7ddeb15a..d2f5f4d2ed 100644 --- a/include/genn/genn/customConnectivityUpdateModels.h +++ b/include/genn/genn/customConnectivityUpdateModels.h @@ -31,11 +31,14 @@ class GENN_EXPORT Base : public Models::Base //---------------------------------------------------------------------------- //! Gets names and types (as strings) of state variables that are common //! across all synapses coming from the same presynaptic neuron - virtual VarVec getPreVars() const { return {}; } + virtual std::vector getPreVars() const { return {}; } //! Gets names and types (as strings) of state variables that are common //! across all synapses going to the same postsynaptic neuron - virtual VarVec getPostVars() const { return {}; } + virtual std::vector getPostVars() const { return {}; } + + //! Gets model variables + virtual std::vector getVars() const{ return {}; } //! Gets names and types (as strings) of synapse variable references virtual VarRefVec getVarRefs() const { return {}; } @@ -56,6 +59,12 @@ class GENN_EXPORT Base : public Models::Base // Public API //---------------------------------------------------------------------------- //! Find the index of a named presynaptic variable + size_t getVarIndex(const std::string &varName) const + { + return getNamedVecIndex(varName, getVars()); + + } + //! Find the index of a named presynaptic variable size_t getPreVarIndex(const std::string &varName) const { return getNamedVecIndex(varName, getPreVars()); diff --git a/include/genn/genn/customUpdateModels.h b/include/genn/genn/customUpdateModels.h index ad40ab65dd..8b26b46caa 100644 --- a/include/genn/genn/customUpdateModels.h +++ b/include/genn/genn/customUpdateModels.h @@ -24,6 +24,9 @@ class GENN_EXPORT Base : public Models::Base //---------------------------------------------------------------------------- // Declared virtuals //---------------------------------------------------------------------------- + //! 
Gets model variables + virtual std::vector getVars() const{ return {}; } + //! Gets names and typesn of model variable references virtual VarRefVec getVarRefs() const{ return {}; } @@ -39,6 +42,12 @@ class GENN_EXPORT Base : public Models::Base //! Update hash from model boost::uuids::detail::sha1::digest_type getHashDigest() const; + //! Find the index of a named variable + size_t getVarIndex(const std::string &varName) const + { + return getNamedVecIndex(varName, getVars()); + } + //! Validate names of parameters etc void validate(const std::unordered_map ¶mValues, const std::unordered_map &varValues, diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 3c341a7527..1b2fd0cc90 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -52,30 +52,54 @@ class GENN_EXPORT Base : public Snippet::Base /*! Explicit constructors required as although, through the wonders of C++ aggregate initialization, access would default to VarAccess::READ_WRITE if not specified, this results in a -Wmissing-field-initializers warning on GCC and Clang*/ - struct GENN_EXPORT Var + template + struct VarBase { - Var(const std::string &n, const Type::ResolvedType &t) - : name(n), type(t) - {} - Var(const std::string &n, const std::string &t) - : name(n), type(t) - {} - - Var(const std::string &n, const Type::ResolvedType &t, VarAccess a) + VarBase(const std::string &n, const Type::ResolvedType &t, A a) : name(n), type(t), access(a) {} - Var(const std::string &n, const std::string &t, VarAccess a) + VarBase(const std::string &n, const std::string &t, A a) : name(n), type(t), access(a) {} - bool operator == (const Var &other) const + bool operator == (const VarBase &other) const { return (std::tie(name, type, access) == std::tie(other.name, other.type, other.access)); } std::string name; Type::UnresolvedType type; - VarAccess access; + A access; + }; + + struct NeuronVar : public VarBase + { + NeuronVar(const std::string &n, const Type::ResolvedType &t) + : VarBase(n, t, NeuronVarAccess::READ_WRITE) + {} + NeuronVar(const std::string &n, const std::string &t) + : VarBase(n, t, NeuronVarAccess::READ_WRITE) + {} + }; + + struct SynapseVar : public VarBase + { + SynapseVar(const std::string &n, const Type::ResolvedType &t) + : VarBase(n, t, SynapseVarAccess::READ_WRITE) + {} + SynapseVar(const std::string &n, const std::string &t) + : VarBase(n, t, SynapseVarAccess::READ_WRITE) + {} + }; + + struct CustomUpdateVar : public VarBase + { + CustomUpdateVar(const std::string &n, const Type::ResolvedType &t) + : VarBase(n, t, CustomUpdateVarAccess::READ_WRITE) + {} + CustomUpdateVar(const std::string &n, const std::string &t) + : VarBase(n, t, CustomUpdateVarAccess::READ_WRITE) + {} }; struct GENN_EXPORT VarRef @@ -113,25 +137,9 @@ class GENN_EXPORT Base : public Snippet::Base //---------------------------------------------------------------------------- // Typedefines //---------------------------------------------------------------------------- - typedef std::vector VarVec; typedef std::vector VarRefVec; typedef std::vector EGPRefVec; - //---------------------------------------------------------------------------- - // Declared virtuals - //------------------------------------------------------------------------ - //! Gets model variables - virtual VarVec getVars() const{ return {}; } - - //------------------------------------------------------------------------ - // Public methods - //------------------------------------------------------------------------ - //! 
Find the index of a named variable - size_t getVarIndex(const std::string &varName) const - { - return getNamedVecIndex(varName, getVars()); - } - protected: //------------------------------------------------------------------------ // Protected methods diff --git a/include/genn/genn/neuronModels.h b/include/genn/genn/neuronModels.h index a2bd9b6b52..8b1ec5bc10 100644 --- a/include/genn/genn/neuronModels.h +++ b/include/genn/genn/neuronModels.h @@ -34,6 +34,9 @@ class GENN_EXPORT Base : public Models::Base //---------------------------------------------------------------------------- // Declared virtuals //---------------------------------------------------------------------------- + //! Gets model variables + virtual std::vector getVars() const{ return {}; } + //! Gets the code that defines the execution of one timestep of integration of the neuron model. /*! The code will refer to $(NN) for the value of the variable with name "NN". It needs to refer to the predefined variable "ISYN", i.e. contain $(ISYN), if it is to receive input. */ @@ -59,6 +62,12 @@ class GENN_EXPORT Base : public Models::Base //! Update hash from model boost::uuids::detail::sha1::digest_type getHashDigest() const; + //! Find the index of a named variable + size_t getVarIndex(const std::string &varName) const + { + return getNamedVecIndex(varName, getVars()); + } + //! Validate names of parameters etc void validate(const std::unordered_map ¶mValues, const std::unordered_map &varValues, diff --git a/include/genn/genn/postsynapticModels.h b/include/genn/genn/postsynapticModels.h index 264ed5aed9..abbac6d73b 100644 --- a/include/genn/genn/postsynapticModels.h +++ b/include/genn/genn/postsynapticModels.h @@ -25,6 +25,9 @@ class GENN_EXPORT Base : public Models::Base //---------------------------------------------------------------------------- // Declared virtuals //---------------------------------------------------------------------------- + //! Gets model variables + virtual std::vector getVars() const{ return {}; } + virtual std::string getDecayCode() const{ return ""; } virtual std::string getApplyInputCode() const{ return ""; } @@ -34,6 +37,12 @@ class GENN_EXPORT Base : public Models::Base //! Update hash from model boost::uuids::detail::sha1::digest_type getHashDigest() const; + //! Find the index of a named variable + size_t getVarIndex(const std::string &varName) const + { + return getNamedVecIndex(varName, getVars()); + } + //! Validate names of parameters etc void validate(const std::unordered_map ¶mValues, const std::unordered_map &varValues, diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h index 5a2e58b4fa..fa27989bff 100644 --- a/include/genn/genn/varAccess.h +++ b/include/genn/genn/varAccess.h @@ -116,77 +116,33 @@ inline VarAccessDim clearDim(VarAccessDim a, VarAccessDim b) return static_cast(static_cast(a) & ~static_cast(b)); } -//---------------------------------------------------------------------------- -// VarAccess -//---------------------------------------------------------------------------- -//! 
Wrapper class encapsulating -class VarAccess -{ -public: - VarAccess() - {} - VarAccess(NeuronVarAccess n) : m_Access{n} - {} - VarAccess(SynapseVarAccess s) : m_Access{s} - {} - VarAccess(CustomUpdateVarAccess c) : m_Access{c} - {} - - //------------------------------------------------------------------------ - // Public API - //------------------------------------------------------------------------ - template - VarAccessDim getDims() const - { - // Extract value - const unsigned int val = std::visit( - Utils::Overload{ - // If access is set to default, use READ_WRITE mode of typed var access e.g. NeuronVarAcccess::READ_WRITE - [](std::monostate) { return static_cast(V::READ_WRITE); }, - // Otherwise, if stored type matches template type, use value - [](V v) { return static_cast(v); }, - // Otherwise, give error - [](auto)->unsigned int { throw std::runtime_error("Invalid var access type"); }}, - m_Access); - - // Mask out dimension bits and cast to enum - return static_cast(val & ~0x1F); - } - - template - bool isValid() const - { - return std::visit( - Utils::Overload{ - [](std::monostate) { return true; }, - [](V) { return true; }, - [](auto) { return false; }}, - m_Access); - } - - void updateHash(boost::uuids::detail::sha1 &hash) const - { - Utils::updateHash(m_Access, hash); - } - - //------------------------------------------------------------------------ - // Operators - //------------------------------------------------------------------------ - operator VarAccessMode() const - { - // If access is set to default, access mode is always read-write otherwise mask out and cast access mode, bits - return std::visit( - Utils::Overload{ - [](std::monostate) { return VarAccessMode::READ_WRITE; }, - [](auto v) { return static_cast(static_cast(v) & 0x1F); }}, - m_Access); - } - -private: - //------------------------------------------------------------------------ - // Members - //------------------------------------------------------------------------ - std::variant m_Access; -}; +inline VarAccessDim getAccessDim(NeuronVarAccess v) +{ + return static_cast(static_cast(v) & ~0x1F); +} +inline VarAccessDim getAccessDim(SynapseVarAccess v) +{ + return static_cast(static_cast(v) & ~0x1F); +} + +inline VarAccessDim getAccessDim(CustomUpdateVarAccess v, VarAccessDim popDims) +{ + return clearDim(popDims, static_cast(static_cast(v) & ~0x1F)); +} + +inline VarAccessMode getVarAccessMode(NeuronVarAccess v) +{ + return static_cast(static_cast(v) & 0x1F); +} + +inline VarAccessMode getVarAccessMode(SynapseVarAccess v) +{ + return static_cast(static_cast(v) & 0x1F); +} + +inline VarAccessMode getVarAccessMode(CustomUpdateVarAccess v) +{ + return static_cast(static_cast(v) & 0x1F); +} } // namespace GeNN diff --git a/include/genn/genn/weightUpdateModels.h b/include/genn/genn/weightUpdateModels.h index 8837011e67..038a74b48e 100644 --- a/include/genn/genn/weightUpdateModels.h +++ b/include/genn/genn/weightUpdateModels.h @@ -71,17 +71,26 @@ class GENN_EXPORT Base : public Models::Base and synapse variables are not accesible from within this code */ virtual std::string getPostDynamicsCode() const{ return ""; } + //! Gets model variables + virtual std::vector getVars() const{ return {}; } + //! Gets names and types (as strings) of state variables that are common //! across all synapses coming from the same presynaptic neuron - virtual VarVec getPreVars() const{ return {}; } + virtual std::vector getPreVars() const{ return {}; } //! Gets names and types (as strings) of state variables that are common //! 
across all synapses going to the same postsynaptic neuron - virtual VarVec getPostVars() const{ return {}; } + virtual std::vector getPostVars() const{ return {}; } //------------------------------------------------------------------------ // Public methods //------------------------------------------------------------------------ + //! Find the index of a named variable + size_t getVarIndex(const std::string &varName) const + { + return getNamedVecIndex(varName, getVars()); + } + //! Find the index of a named presynaptic variable size_t getPreVarIndex(const std::string &varName) const { From d87a3af15bc2e34fa231ba0dfb7893cc14543425 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 5 Sep 2023 14:20:32 +0100 Subject: [PATCH 44/60] updated adapters --- include/genn/genn/currentSourceInternal.h | 4 ++-- .../genn/genn/customConnectivityUpdateInternal.h | 12 ++++++------ include/genn/genn/customUpdate.h | 6 +++--- include/genn/genn/neuronGroupInternal.h | 4 ++-- include/genn/genn/synapseGroupInternal.h | 16 ++++++++-------- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/include/genn/genn/currentSourceInternal.h b/include/genn/genn/currentSourceInternal.h index b6c7277500..b0aa62b380 100644 --- a/include/genn/genn/currentSourceInternal.h +++ b/include/genn/genn/currentSourceInternal.h @@ -46,7 +46,7 @@ class CurrentSourceVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_CS.getVarLocation(varName); } - Models::Base::VarVec getDefs() const{ return m_CS.getCurrentSourceModel()->getVars(); } + std::vector getDefs() const{ return m_CS.getCurrentSourceModel()->getVars(); } const std::unordered_map &getInitialisers() const{ return m_CS.getVarInitialisers(); } @@ -54,7 +54,7 @@ class CurrentSourceVarAdapter const std::string &getNameSuffix() const{ return m_CS.getName(); } - VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarDims(var.access); } private: //---------------------------------------------------------------------------- diff --git a/include/genn/genn/customConnectivityUpdateInternal.h b/include/genn/genn/customConnectivityUpdateInternal.h index da6ebb2e5d..9d680ad1b5 100644 --- a/include/genn/genn/customConnectivityUpdateInternal.h +++ b/include/genn/genn/customConnectivityUpdateInternal.h @@ -55,13 +55,13 @@ class CustomConnectivityUpdateVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_CU.getVarLocation(varName); } - Models::Base::VarVec getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getVars(); } + std::vector getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getVars(); } const std::unordered_map &getInitialisers() const{ return m_CU.getVarInitialisers(); } const std::string &getNameSuffix() const{ return m_CU.getName(); } - VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + VarAccessDim getVarDims(const Models::Base::SynapseVar &var) const{ return getAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -84,7 +84,7 @@ class CustomConnectivityUpdatePreVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return 
m_CU.getPreVarLocation(varName); } - Models::Base::VarVec getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getPreVars(); } + std::vector getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getPreVars(); } const std::unordered_map &getInitialisers() const{ return m_CU.getPreVarInitialisers(); } @@ -92,7 +92,7 @@ class CustomConnectivityUpdatePreVarAdapter const std::string &getNameSuffix() const{ return m_CU.getName(); } - VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -115,7 +115,7 @@ class CustomConnectivityUpdatePostVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_CU.getPostVarLocation(varName); } - Models::Base::VarVec getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getPostVars(); } + std::vector getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getPostVars(); } const std::unordered_map &getInitialisers() const{ return m_CU.getPostVarInitialisers(); } @@ -123,7 +123,7 @@ class CustomConnectivityUpdatePostVarAdapter const std::string &getNameSuffix() const{ return m_CU.getName(); } - VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getAccessDim(var.access); } private: //---------------------------------------------------------------------------- diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index d6d947ac99..12b4444c12 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -200,7 +200,7 @@ class CustomUpdateVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_CU.getVarLocation(varName); } - Models::Base::VarVec getDefs() const{ return m_CU.getCustomUpdateModel()->getVars(); } + std::vector getDefs() const{ return m_CU.getCustomUpdateModel()->getVars(); } const std::unordered_map &getInitialisers() const{ return m_CU.getVarInitialisers(); } @@ -208,9 +208,9 @@ class CustomUpdateVarAdapter const std::string &getNameSuffix() const{ return m_CU.getName(); } - VarAccessDim getVarDims(const Models::Base::Var &var) const + VarAccessDim getVarDims(const Models::Base::CustomUpdateVar &var) const { - return clearDim(m_CU.getDims(), var.access.getDims()); + return getAccessDim(var.access, m_CU.getDims()); } private: diff --git a/include/genn/genn/neuronGroupInternal.h b/include/genn/genn/neuronGroupInternal.h index 5120b5c482..9eac87d979 100644 --- a/include/genn/genn/neuronGroupInternal.h +++ b/include/genn/genn/neuronGroupInternal.h @@ -70,7 +70,7 @@ class NeuronVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_NG.getVarLocation(varName); } - Models::Base::VarVec getDefs() const{ return m_NG.getNeuronModel()->getVars(); } + std::vector getDefs() const{ return m_NG.getNeuronModel()->getVars(); } const std::unordered_map &getInitialisers() const{ return m_NG.getVarInitialisers(); } @@ -78,7 +78,7 @@ class NeuronVarAdapter const std::string &getNameSuffix() const{ return m_NG.getName(); } - VarAccessDim getVarDims(const
Models::Base::Var &var) const{ return var.access.getDims(); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getAccessDim(var.access); } private: //---------------------------------------------------------------------------- diff --git a/include/genn/genn/synapseGroupInternal.h b/include/genn/genn/synapseGroupInternal.h index 7031994651..17947d6dd3 100644 --- a/include/genn/genn/synapseGroupInternal.h +++ b/include/genn/genn/synapseGroupInternal.h @@ -110,7 +110,7 @@ class SynapsePSMVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_SG.getPSVarLocation(varName); } - Models::Base::VarVec getDefs() const{ return m_SG.getPSModel()->getVars(); } + std::vector getDefs() const{ return m_SG.getPSModel()->getVars(); } const std::unordered_map &getInitialisers() const{ return m_SG.getPSVarInitialisers(); } @@ -118,7 +118,7 @@ class SynapsePSMVarAdapter bool isVarDelayed(const std::string &) const { return false; } - VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -164,13 +164,13 @@ class SynapseWUVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_SG.getWUVarLocation(varName); } - Models::Base::VarVec getDefs() const{ return m_SG.getWUModel()->getVars(); } + std::vector getDefs() const{ return m_SG.getWUModel()->getVars(); } const std::unordered_map &getInitialisers() const{ return m_SG.getWUVarInitialisers(); } const std::string &getNameSuffix() const{ return m_SG.getName(); } - VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + VarAccessDim getVarDims(const Models::Base::SynapseVar &var) const{ return getAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -193,7 +193,7 @@ class SynapseWUPreVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_SG.getWUPreVarLocation(varName); } - Models::Base::VarVec getDefs() const{ return m_SG.getWUModel()->getPreVars(); } + std::vector getDefs() const{ return m_SG.getWUModel()->getPreVars(); } const std::unordered_map &getInitialisers() const{ return m_SG.getWUPreVarInitialisers(); } @@ -201,7 +201,7 @@ class SynapseWUPreVarAdapter bool isVarDelayed(const std::string&) const{ return (m_SG.getDelaySteps() != 0); } - VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -224,7 +224,7 @@ class SynapseWUPostVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_SG.getWUPostVarLocation(varName); } - Models::Base::VarVec getDefs() const{ return m_SG.getWUModel()->getPostVars(); } + std::vector getDefs() const{ return m_SG.getWUModel()->getPostVars(); } const std::unordered_map &getInitialisers() const{ return m_SG.getWUPostVarInitialisers(); } @@ -232,7 +232,7 @@ class 
SynapseWUPostVarAdapter bool isVarDelayed(const std::string&) const{ return (m_SG.getBackPropDelaySteps() != 0); } - VarAccessDim getVarDims(const Models::Base::Var &var) const{ return var.access.getDims(); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getAccessDim(var.access); } private: //---------------------------------------------------------------------------- From 5b5efdb4faaee3ee9417aca69cbe04e32fd30e46 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 5 Sep 2023 18:10:33 +0100 Subject: [PATCH 45/60] moved generic model validation and hashing out of base class --- src/genn/genn/currentSourceModels.cc | 17 +++++----- .../genn/customConnectivityUpdateModels.cc | 21 ++----------- src/genn/genn/customUpdateModels.cc | 22 ++++++++++--- src/genn/genn/neuronModels.cc | 17 +++++----- src/genn/genn/postsynapticModels.cc | 17 +++++----- src/genn/genn/weightUpdateModels.cc | 31 ++++--------------- 6 files changed, 52 insertions(+), 73 deletions(-) diff --git a/src/genn/genn/currentSourceModels.cc b/src/genn/genn/currentSourceModels.cc index 9bd9c3e75f..ec1905303f 100644 --- a/src/genn/genn/currentSourceModels.cc +++ b/src/genn/genn/currentSourceModels.cc @@ -19,8 +19,9 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const { // Superclass boost::uuids::detail::sha1 hash; - Models::Base::updateHash(hash); + Snippet::Base::updateHash(hash); + Utils::updateHash(getVars(), hash); Utils::updateHash(getInjectionCode(), hash); return hash.get_digest(); } @@ -30,14 +31,14 @@ void Base::validate(const std::unordered_map ¶mValues, const std::string &description) const { // Superclass - Models::Base::validate(paramValues, varValues, description); + Snippet::Base::validate(paramValues, description); - // If any variables have a reduction access mode, give an error + // Validate variable names const auto vars = getVars(); - if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return !v.access.template isValid(); })) - { - throw std::runtime_error("Current source model variables must have NeuronVarAccess access type"); - } + Utils::validateVecNames(vars, "Variable"); + + // Validate variable initialisers + Utils::validateInitialisers(vars, varValues, "variable", description); + } } // namespace GeNN::CurrentSourceModels \ No newline at end of file diff --git a/src/genn/genn/customConnectivityUpdateModels.cc b/src/genn/genn/customConnectivityUpdateModels.cc index 6715dc6ed1..ae494a2267 100644 --- a/src/genn/genn/customConnectivityUpdateModels.cc +++ b/src/genn/genn/customConnectivityUpdateModels.cc @@ -37,11 +37,12 @@ void Base::validate(const std::unordered_map ¶mValues, const std::string &description) const { // Superclass - Models::Base::validate(paramValues, varValues, description); + Models::Base::validate(paramValues, description); const auto vars = getVars(); const auto preVars = getPreVars(); const auto postVars = getPostVars(); + Utils::validateVecNames(vars, "Variable"); Utils::validateVecNames(preVars, "Presynaptic variable"); Utils::validateVecNames(postVars, "Presynaptic variable"); Utils::validateVecNames(getVarRefs(), "Synapse variable reference"); @@ -49,6 +50,7 @@ void Base::validate(const std::unordered_map ¶mValues, Utils::validateVecNames(getPostVarRefs(), "Postsynaptic variable reference"); // Validate variable initialisers + Utils::validateInitialisers(vars, varValues, "variable", description); Utils::validateInitialisers(preVars, preVarValues, "presynaptic variable", description); 
Utils::validateInitialisers(postVars, postVarValues, "postsynaptic variable", description); @@ -56,22 +58,5 @@ void Base::validate(const std::unordered_map ¶mValues, Utils::validateInitialisers(getVarRefs(), varRefTargets, "variable reference", description); Utils::validateInitialisers(getPreVarRefs(), preVarRefTargets, "presynaptic variable reference", description); Utils::validateInitialisers(getPostVarRefs(), postVarRefTargets, "postsynaptic variable reference", description); - - // Check variables have suitable access types - if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return !v.access.template isValid(); })) - { - throw std::runtime_error("Custom connectivity update models variables must have SynapseVarAccess access type"); - } - if(std::any_of(preVars.cbegin(), preVars.cend(), - [](const Models::Base::Var &v){ return !v.access.template isValid(); })) - { - throw std::runtime_error("Custom connectivity update models presynaptic variables must have NeuronVarAccess access type"); - } - if(std::any_of(postVars.cbegin(), postVars.cend(), - [](const Models::Base::Var &v){ return !v.access.template isValid(); })) - { - throw std::runtime_error("Custom connectivity update models postsynaptic variables must have NeuronVarAccess access type"); - } } } // namespace GeNN::CustomConnectivityUpdateModels \ No newline at end of file diff --git a/src/genn/genn/customUpdateModels.cc b/src/genn/genn/customUpdateModels.cc index d0f0018a66..984c96e9ff 100644 --- a/src/genn/genn/customUpdateModels.cc +++ b/src/genn/genn/customUpdateModels.cc @@ -17,8 +17,8 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const { // Superclass boost::uuids::detail::sha1 hash; - Models::Base::updateHash(hash); - + Snippet::Base::updateHash(hash); + Utils::updateHash(getVars(), hash); Utils::updateHash(getUpdateCode(), hash); Utils::updateHash(getVarRefs(), hash); Utils::updateHash(getExtraGlobalParamRefs(), hash); @@ -31,7 +31,14 @@ void Base::validate(const std::unordered_map ¶mValues, const std::string &description) const { // Superclass - Models::Base::validate(paramValues, varValues, description); + Snippet::Base::validate(paramValues, description); + + // Validate variable names + const auto vars = getVars(); + Utils::validateVecNames(vars, "Variable"); + + // Validate variable initialisers + Utils::validateInitialisers(vars, varValues, "variable", description); const auto varRefs = getVarRefs(); Utils::validateVecNames(varRefs, "Variable reference"); @@ -47,7 +54,14 @@ void Base::validate(const std::unordered_map ¶mValues, const std::string &description) const { // Superclass - Models::Base::validate(paramValues, varValues, description); + Snippet::Base::validate(paramValues, description); + + // Validate variable names + const auto vars = getVars(); + Utils::validateVecNames(vars, "Variable"); + + // Validate variable initialisers + Utils::validateInitialisers(vars, varValues, "variable", description); const auto varRefs = getVarRefs(); Utils::validateVecNames(getVarRefs(), "Variable reference"); diff --git a/src/genn/genn/neuronModels.cc b/src/genn/genn/neuronModels.cc index ac03382abb..524e02e2ba 100644 --- a/src/genn/genn/neuronModels.cc +++ b/src/genn/genn/neuronModels.cc @@ -28,8 +28,8 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const { // Superclass boost::uuids::detail::sha1 hash; - Models::Base::updateHash(hash); - + Snippet::Base::updateHash(hash); + Utils::updateHash(getVars(), hash); Utils::updateHash(getSimCode(), hash); 
Utils::updateHash(getThresholdConditionCode(), hash); Utils::updateHash(getResetCode(), hash); @@ -43,16 +43,15 @@ void Base::validate(const std::unordered_map ¶mValues, const std::string &description) const { // Superclass - Models::Base::validate(paramValues, varValues, description); + Models::Base::validate(paramValues, description); Utils::validateVecNames(getAdditionalInputVars(), "Additional input variable"); - // If any variables have an invalid access mode, give an error + // Validate variable names const auto vars = getVars(); - if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return !v.access.template isValid(); })) - { - throw std::runtime_error("Neuron model variables must have NeuronVarAccess access type"); - } + Utils::validateVecNames(vars, "Variable"); + + // Validate variable initialisers + Utils::validateInitialisers(vars, varValues, "variable", description); } } // namespace GeNN::NeuronModels diff --git a/src/genn/genn/postsynapticModels.cc b/src/genn/genn/postsynapticModels.cc index 65bcb36e37..fdfa432461 100644 --- a/src/genn/genn/postsynapticModels.cc +++ b/src/genn/genn/postsynapticModels.cc @@ -19,8 +19,8 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const { // Superclass boost::uuids::detail::sha1 hash; - Models::Base::updateHash(hash); - + Snippet::Base::updateHash(hash); + Utils::updateHash(getVars(), hash); Utils::updateHash(getDecayCode(), hash); Utils::updateHash(getApplyInputCode(), hash); return hash.get_digest(); @@ -31,14 +31,13 @@ void Base::validate(const std::unordered_map ¶mValues, const std::string &description) const { // Superclass - Models::Base::validate(paramValues, varValues, description); + Snippet::Base::validate(paramValues, description); - // If any variables have a reduction access mode, give an error + // Validate variable names const auto vars = getVars(); - if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return !v.access.template isValid(); })) - { - throw std::runtime_error("Postsynaptic model variables must have NeuronVarAccess access type"); - } + Utils::validateVecNames(vars, "Variable"); + + // Validate variable initialisers + Utils::validateInitialisers(vars, varValues, "variable", description); } } // namespace GeNN::PostsynapticModels \ No newline at end of file diff --git a/src/genn/genn/weightUpdateModels.cc b/src/genn/genn/weightUpdateModels.cc index d47f33145b..e3ec73d8d0 100644 --- a/src/genn/genn/weightUpdateModels.cc +++ b/src/genn/genn/weightUpdateModels.cc @@ -20,8 +20,8 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const { // Superclass boost::uuids::detail::sha1 hash; - Models::Base::updateHash(hash); - + Snippet::Base::updateHash(hash); + Utils::updateHash(getVars(), hash); Utils::updateHash(getSimCode(), hash); Utils::updateHash(getEventCode(), hash); Utils::updateHash(getLearnPostCode(), hash); @@ -41,7 +41,6 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const boost::uuids::detail::sha1::digest_type Base::getPreHashDigest() const { // Superclass - // **NOTE** we skip over Models::Base::updateHash to avoid hashing synaptic variables boost::uuids::detail::sha1 hash; Snippet::Base::updateHash(hash); @@ -56,10 +55,8 @@ boost::uuids::detail::sha1::digest_type Base::getPreHashDigest() const boost::uuids::detail::sha1::digest_type Base::getPostHashDigest() const { // Superclass - // **NOTE** we skip over Models::Base::updateHash to avoid hashing synaptic variables boost::uuids::detail::sha1 hash; 
Snippet::Base::updateHash(hash); - Utils::updateHash(getPostSpikeCode(), hash); Utils::updateHash(getPostDynamicsCode(), hash); Utils::updateHash(getPostVars(), hash); @@ -75,34 +72,18 @@ void Base::validate(const std::unordered_map ¶mValues, const std::string &description) const { // Superclass - Models::Base::validate(paramValues, varValues, description); + Snippet::Base::validate(paramValues, description); - + const auto vars = getVars(); const auto preVars = getPreVars(); const auto postVars = getPostVars(); + Utils::validateVecNames(getVars(), "Variable"); Utils::validateVecNames(getPreVars(), "Presynaptic variable"); Utils::validateVecNames(getPostVars(), "Presynaptic variable"); // Validate variable initialisers + Utils::validateInitialisers(vars, preVarValues, "variable", description); Utils::validateInitialisers(preVars, preVarValues, "presynaptic variable", description); Utils::validateInitialisers(postVars, postVarValues, "postsynaptic variable", description); - - // Check variables have suitable access types - const auto vars = getVars(); - if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return !v.access.template isValid(); })) - { - throw std::runtime_error("Weight update models variables must have SynapseVarAccess access type"); - } - if(std::any_of(preVars.cbegin(), preVars.cend(), - [](const Models::Base::Var &v){ return !v.access.template isValid(); })) - { - throw std::runtime_error("Weight update models presynaptic variables must have NeuronVarAccess access type"); - } - if(std::any_of(postVars.cbegin(), postVars.cend(), - [](const Models::Base::Var &v){ return !v.access.template isValid(); })) - { - throw std::runtime_error("Weight update models postsynaptic variables must have NeuronVarAccess access type"); - } } } // namespace WeightUpdateModels \ No newline at end of file From 609386d64942fe425068b7ebb35b6817a5313762 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 8 Sep 2023 10:31:31 +0100 Subject: [PATCH 46/60] WIP var reference --- .../genn/genn/code_generator/environment.h | 31 +- include/genn/genn/currentSourceInternal.h | 2 +- include/genn/genn/currentSourceModels.h | 2 +- include/genn/genn/customUpdate.h | 6 +- include/genn/genn/models.h | 132 +++--- include/genn/genn/neuronModels.h | 20 +- include/genn/genn/varAccess.h | 18 +- include/genn/genn/weightUpdateModels.h | 8 +- src/genn/genn/code_generator/backendBase.cc | 4 +- .../customConnectivityUpdateGroupMerged.cc | 2 +- .../code_generator/customUpdateGroupMerged.cc | 2 +- src/genn/genn/models.cc | 423 ++++++++++-------- 12 files changed, 359 insertions(+), 291 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index f7fc6844f6..a81bdb6529 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -405,7 +405,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &(GroupInternal::*)(void) const; - using GetVarIndexFn = std::function; + using GetVarIndexFn = std::function; template using GetVarRefIndexFn = std::function; @@ -695,6 +695,10 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase>> m_Environment; }; +//! 
Type of a single definition +template +using Def = typename std::invoke_result_t::value_type; + //------------------------------------------------------------------------ // GeNN::CodeGenerator::VarCachePolicy //------------------------------------------------------------------------ @@ -704,8 +708,8 @@ class VarCachePolicy { public: using GroupInternal = typename G::GroupInternal; - using GetIndexFn = std::function; - using ShouldAlwaysCopyFn = std::function; + using GetIndexFn = std::function; + using ShouldAlwaysCopyFn = std::function; VarCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex, ShouldAlwaysCopyFn shouldAlwaysCopy = ShouldAlwaysCopyFn()) @@ -721,7 +725,7 @@ class VarCachePolicy //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - bool shouldAlwaysCopy(G&, const Models::Base::Var &var) const + bool shouldAlwaysCopy(G&, const Def &var) const { if(m_ShouldAlwaysCopy) { return m_ShouldAlwaysCopy(var.name, var.access); @@ -731,17 +735,17 @@ class VarCachePolicy } } - std::string getReadIndex(G&, const Models::Base::Var &var) const + std::string getReadIndex(G&, const Def &var) const { return m_GetReadIndex(var.name, var.access); } - std::string getWriteIndex(G&, const Models::Base::Var &var) const + std::string getWriteIndex(G&, const Def &var) const { return m_GetWriteIndex(var.name, var.access); } - std::string getTargetName(const GroupInternal &g, const Models::Base::Var &var) const + std::string getTargetName(const GroupInternal &g, const Def &var) const { return var.name + A(g).getNameSuffix(); } @@ -777,24 +781,24 @@ class VarRefCachePolicy //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - bool shouldAlwaysCopy(G&, const Models::Base::VarRef&) const + bool shouldAlwaysCopy(G&, const Def&) const { // **NOTE** something else is managing the actual variables // and is therefore responsible for copying between delay slots etc return false; } - std::string getReadIndex(G &g, const Models::Base::VarRef &var) const + std::string getReadIndex(G &g, const Def &var) const { return m_GetReadIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name)); } - std::string getWriteIndex(G &g, const Models::Base::VarRef &var) const + std::string getWriteIndex(G &g, const Def &var) const { return m_GetWriteIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name)); } - std::string getTargetName(const GroupInternal &g, const Models::Base::VarRef &var) const + std::string getTargetName(const GroupInternal &g, const Def &var) const { const auto &initialiser = A(g).getInitialisers().at(var.name); return initialiser.getVar().name + initialiser.getTargetName(); @@ -815,8 +819,7 @@ class VarRefCachePolicy template class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P { - //! 
Type of a single definition - using Def = typename std::invoke_result_t::value_type; + public: template @@ -841,7 +844,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P // Copy definitions of variables which have been referenced into new vector or all if always copy set const auto varDefs = archetypeAdapter.getDefs(); - std::vector referencedDefs; + std::vector> referencedDefs; std::copy_if(varDefs.cbegin(), varDefs.cend(), std::back_inserter(referencedDefs), [this](const auto &v) { diff --git a/include/genn/genn/currentSourceInternal.h b/include/genn/genn/currentSourceInternal.h index b0aa62b380..7d9ed39d2a 100644 --- a/include/genn/genn/currentSourceInternal.h +++ b/include/genn/genn/currentSourceInternal.h @@ -54,7 +54,7 @@ class CurrentSourceVarAdapter const std::string &getNameSuffix() const{ return m_CS.getName(); } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarDims(var.access); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getAccessDim(var.access); } private: //---------------------------------------------------------------------------- diff --git a/include/genn/genn/currentSourceModels.h b/include/genn/genn/currentSourceModels.h index 863c390db4..65709d8253 100644 --- a/include/genn/genn/currentSourceModels.h +++ b/include/genn/genn/currentSourceModels.h @@ -113,7 +113,7 @@ class PoissonExp : public Base "current *= ExpDecay;\n"); SET_PARAM_NAMES({"weight", "tauSyn", "rate"}); - SET_VARS({{"current", "scalar"}}); + SET_NEURON_VARS({{"current", "scalar"}}); SET_DERIVED_PARAMS({ {"ExpDecay", [](const std::unordered_map &pars, double dt){ return std::exp(-dt / pars.at("tauSyn")); }}, {"Init", [](const std::unordered_map &pars, double dt){ return pars.at("weight") * (1.0 - std::exp(-dt / pars.at("tauSyn"))) * (pars.at("tauSyn") / dt); }}, diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index 12b4444c12..30fa9df2bf 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -93,10 +93,10 @@ class GENN_EXPORT CustomUpdateBase // **NOTE** this is correct because custom update variable access types are defined subtractively const auto vars = getCustomUpdateModel()->getVars(); if(std::any_of(vars.cbegin(), vars.cend(), - [reduceDim](const Models::Base::Var &v) + [reduceDim](const Models::Base::CustomUpdateVar &v) { return ((v.access & VarAccessModeAttribute::REDUCE) - && (v.access.getDims() & reduceDim)); + && (v.access & reduceDim)); })) { return true; @@ -108,7 +108,7 @@ class GENN_EXPORT CustomUpdateBase // and the variable it targets doesn't have reduction dimension const auto &varRef = varRefs.at(modelVarRef.name); if ((modelVarRef.access & VarAccessModeAttribute::REDUCE) - && !(varRef.getDims() & reduceDim)) + && !(varRef.getVarDims() & reduceDim)) { return true; } diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 1b2fd0cc90..f764a06f01 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -33,8 +33,10 @@ class CustomConnectivityUpdateInternal; //---------------------------------------------------------------------------- // Macros //---------------------------------------------------------------------------- -#define SET_VARS(...) virtual VarVec getVars() const override{ return __VA_ARGS__; } -#define DEFINE_REF_DETAIL_STRUCT(NAME, GROUP_TYPE) using NAME = Detail +#define SET_NEURON_VARS(...) 
virtual std::vector getVars() const override{ return __VA_ARGS__; } +#define SET_SYNAPSE_VARS(...) virtual std::vector getVars() const override{ return __VA_ARGS__; } +#define SET_CUSTOM_UPDATE_VARS(...) virtual std::vector getVars() const override{ return __VA_ARGS__; } +#define DEFINE_REF_DETAIL_STRUCT(NAME, GROUP_TYPE, VAR_TYPE) using NAME = Detail //---------------------------------------------------------------------------- // GeNN::Models::Base @@ -74,6 +76,8 @@ class GENN_EXPORT Base : public Snippet::Base struct NeuronVar : public VarBase { + using VarBase::VarBase; + NeuronVar(const std::string &n, const Type::ResolvedType &t) : VarBase(n, t, NeuronVarAccess::READ_WRITE) {} @@ -84,6 +88,8 @@ class GENN_EXPORT Base : public Snippet::Base struct SynapseVar : public VarBase { + using VarBase::VarBase; + SynapseVar(const std::string &n, const Type::ResolvedType &t) : VarBase(n, t, SynapseVarAccess::READ_WRITE) {} @@ -94,6 +100,8 @@ class GENN_EXPORT Base : public Snippet::Base struct CustomUpdateVar : public VarBase { + using VarBase::VarBase; + CustomUpdateVar(const std::string &n, const Type::ResolvedType &t) : VarBase(n, t, CustomUpdateVarAccess::READ_WRITE) {} @@ -139,18 +147,6 @@ class GENN_EXPORT Base : public Snippet::Base //---------------------------------------------------------------------------- typedef std::vector VarRefVec; typedef std::vector EGPRefVec; - -protected: - //------------------------------------------------------------------------ - // Protected methods - //------------------------------------------------------------------------ - void updateHash(boost::uuids::detail::sha1 &hash) const; - - //! Validate names of parameters etc - void validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, - const std::string &description) const; - }; //---------------------------------------------------------------------------- @@ -158,36 +154,18 @@ class GENN_EXPORT Base : public Snippet::Base //---------------------------------------------------------------------------- class GENN_EXPORT VarReferenceBase { -public: - //------------------------------------------------------------------------ - // Public API - //------------------------------------------------------------------------ - const Models::Base::Var &getVar() const { return m_Var; } - size_t getVarIndex() const { return m_VarIndex; } - - protected: //------------------------------------------------------------------------ // Detail //------------------------------------------------------------------------ //! Minimal helper class for definining unique struct //! wrappers around group pointers for use with std::variant - template + template struct Detail { G *group; + V var; }; - - VarReferenceBase(size_t varIndex, const Models::Base::VarVec &varVec) - : m_VarIndex(varIndex), m_Var(varVec.at(varIndex)) - {} - -private: - //------------------------------------------------------------------------ - // Members - //------------------------------------------------------------------------ - size_t m_VarIndex; - Models::Base::Var m_Var; }; //---------------------------------------------------------------------------- @@ -199,18 +177,24 @@ class GENN_EXPORT VarReference : public VarReferenceBase //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ + //! 
Get name of variable + const std::string &getVarName() const; + + // Get type of variable + const Type::UnresolvedType &getVarType() const; + + // Get dimensions of variable + VarAccessDim getVarDims() const; + //! Get size of variable - unsigned int getSize() const { return m_Size; } + unsigned int getSize() const; //! If variable is delayed, get neuron group which manages its delay NeuronGroup *getDelayNeuronGroup() const; //! Get suffix to use when accessing target variable names // **TODO** rename to getNameSuffix - std::string getTargetName() const; - - //! Get dimensions of variable being referenced - VarAccessDim getDims() const; + const std::string &getTargetName() const; //! If this reference points to another custom update, return pointer to it /*! This is used to detect circular dependencies */ @@ -237,28 +221,25 @@ class GENN_EXPORT VarReference : public VarReferenceBase //------------------------------------------------------------------------ // Typedefines //------------------------------------------------------------------------ - DEFINE_REF_DETAIL_STRUCT(NGRef, NeuronGroupInternal); - DEFINE_REF_DETAIL_STRUCT(PSMRef, SynapseGroupInternal); - DEFINE_REF_DETAIL_STRUCT(WUPreRef, SynapseGroupInternal); - DEFINE_REF_DETAIL_STRUCT(WUPostRef, SynapseGroupInternal); - DEFINE_REF_DETAIL_STRUCT(CSRef, CurrentSourceInternal); - DEFINE_REF_DETAIL_STRUCT(CURef, CustomUpdateInternal); - DEFINE_REF_DETAIL_STRUCT(CCUPreRef, CustomConnectivityUpdateInternal); - DEFINE_REF_DETAIL_STRUCT(CCUPostRef, CustomConnectivityUpdateInternal); + DEFINE_REF_DETAIL_STRUCT(NGRef, NeuronGroupInternal, Base::NeuronVar); + DEFINE_REF_DETAIL_STRUCT(PSMRef, SynapseGroupInternal, Base::NeuronVar); + DEFINE_REF_DETAIL_STRUCT(WUPreRef, SynapseGroupInternal, Base::NeuronVar); + DEFINE_REF_DETAIL_STRUCT(WUPostRef, SynapseGroupInternal, Base::NeuronVar); + DEFINE_REF_DETAIL_STRUCT(CSRef, CurrentSourceInternal, Base::NeuronVar); + DEFINE_REF_DETAIL_STRUCT(CURef, CustomUpdateInternal, Base::CustomUpdateVar); + DEFINE_REF_DETAIL_STRUCT(CCUPreRef, CustomConnectivityUpdateInternal, Base::NeuronVar); + DEFINE_REF_DETAIL_STRUCT(CCUPostRef, CustomConnectivityUpdateInternal, Base::NeuronVar); //! Variant type used to store 'detail' using DetailType = std::variant; - VarReference(size_t varIndex, const Models::Base::VarVec &varVec, unsigned int size, - const DetailType &detail) - : VarReferenceBase(varIndex, varVec), m_Size(size), m_Detail(detail) + VarReference(const DetailType &detail) : m_Detail(detail) {} //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - unsigned int m_Size; DetailType m_Detail; }; @@ -271,23 +252,33 @@ class GENN_EXPORT WUVarReference : public VarReferenceBase //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - const Models::Base::Var &getTransposeVar() const { return *m_TransposeVar; } - size_t getTransposeVarIndex() const { return *m_TransposeVarIndex; } + //! Get name of variable + const std::string &getVarName() const; + + // Get type of variable + const Type::UnresolvedType &getVarType() const; + + // Get dimensions of variable + VarAccessDim getVarDims() const; //! Get suffix to use when accessing target variable names // **TODO** rename to getNameSuffix - std::string getTargetName() const; + const std::string &getTargetName() const; - //! 
Get dimensions of variable being referenced - VarAccessDim getDims() const{ return getVarDims(getVar()); } - SynapseGroup *getSynapseGroup() const; + + //! Get name of tranpose variable + std::optional getTransposeVarName() const; - SynapseGroup *getTransposeSynapseGroup() const; - std::string getTransposeTargetName() const; + // Get type of transpose variable + std::optional getTransposeVarType() const; //! Get dimensions of transpose variable being referenced - VarAccessDim getTransposeDims() const{ return getVarDims(getTransposeVar()); } + std::optional getTransposeVarDims() const; + + std::optional getTransposeTargetName() const; + + SynapseGroup *getTransposeSynapseGroup() const; //! If this reference points to another custom update, return pointer to it /*! This is used to detect circular dependencies */ @@ -316,13 +307,16 @@ class GENN_EXPORT WUVarReference : public VarReferenceBase { SynapseGroupInternal *group; SynapseGroupInternal *transposeGroup; + + Base::SynapseVar var; + std::optional transposeVar; }; //------------------------------------------------------------------------ // Typedefines //------------------------------------------------------------------------ - DEFINE_REF_DETAIL_STRUCT(CURef, CustomUpdateWUInternal); - DEFINE_REF_DETAIL_STRUCT(CCURef, CustomConnectivityUpdateInternal); + DEFINE_REF_DETAIL_STRUCT(CURef, CustomUpdateWUInternal, Base::CustomUpdateVar); + DEFINE_REF_DETAIL_STRUCT(CCURef, CustomConnectivityUpdateInternal, Base::SynapseVar); //! Variant type used to store 'detail' using DetailType = std::variant; @@ -332,20 +326,12 @@ class GENN_EXPORT WUVarReference : public VarReferenceBase //------------------------------------------------------------------------ SynapseGroupInternal *getSynapseGroupInternal() const; SynapseGroupInternal *getTransposeSynapseGroupInternal() const; - VarAccessDim getVarDims(const Models::Base::Var &var) const; - WUVarReference(size_t varIndex, const Models::Base::VarVec &varVec, - const DetailType &detail); - WUVarReference(size_t varIndex, const Models::Base::VarVec &varVec, - size_t transposeVarIndex, const Models::Base::VarVec &transposeVarVec, - const DetailType &detail); + WUVarReference(const DetailType &detail); //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - std::optional m_TransposeVarIndex; - std::optional m_TransposeVar; - DetailType m_Detail; }; @@ -388,7 +374,9 @@ class GENN_EXPORT EGPReference //---------------------------------------------------------------------------- // updateHash overrides //---------------------------------------------------------------------------- -GENN_EXPORT void updateHash(const Base::Var &v, boost::uuids::detail::sha1 &hash); +GENN_EXPORT void updateHash(const Base::NeuronVar &v, boost::uuids::detail::sha1 &hash); +GENN_EXPORT void updateHash(const Base::SynapseVar &v, boost::uuids::detail::sha1 &hash); +GENN_EXPORT void updateHash(const Base::CustomUpdateVar &v, boost::uuids::detail::sha1 &hash); GENN_EXPORT void updateHash(const Base::VarRef &v, boost::uuids::detail::sha1 &hash); GENN_EXPORT void updateHash(const Base::EGPRef &e, boost::uuids::detail::sha1 &hash); GENN_EXPORT void updateHash(const VarReference &v, boost::uuids::detail::sha1 &hash); diff --git a/include/genn/genn/neuronModels.h b/include/genn/genn/neuronModels.h index 8b1ec5bc10..1e2dccb3be 100644 --- a/include/genn/genn/neuronModels.h +++ b/include/genn/genn/neuronModels.h @@ -128,7 +128,7 @@ class 
RulkovMap : public Base SET_THRESHOLD_CONDITION_CODE("$(V) >= $(ip2)"); SET_PARAM_NAMES({"Vspike", "alpha", "y", "beta"}); - SET_VARS({{"V","scalar"}, {"preV", "scalar"}}); + SET_NEURON_VARS({{"V","scalar"}, {"preV", "scalar"}}); SET_DERIVED_PARAMS({ {"ip0", [](const std::unordered_map &pars, double){ return pars.at("Vspike") * pars.at("Vspike") * pars.at("alpha"); }}, @@ -177,7 +177,7 @@ class Izhikevich : public Base SET_THRESHOLD_CONDITION_CODE("$(V) >= 29.99"); SET_PARAM_NAMES({"a", "b", "c", "d"}); - SET_VARS({{"V","scalar"}, {"U", "scalar"}}); + SET_NEURON_VARS({{"V","scalar"}, {"U", "scalar"}}); SET_NEEDS_AUTO_REFRACTORY(false); }; @@ -205,9 +205,9 @@ class IzhikevichVariable : public Izhikevich DECLARE_SNIPPET(NeuronModels::IzhikevichVariable); SET_PARAM_NAMES({}); - SET_VARS({{"V","scalar"}, {"U", "scalar"}, - {"a", "scalar", NeuronVarAccess::READ_ONLY}, {"b", "scalar", NeuronVarAccess::READ_ONLY}, - {"c", "scalar", NeuronVarAccess::READ_ONLY}, {"d", "scalar", NeuronVarAccess::READ_ONLY}}); + SET_NEURON_VARS({{"V","scalar"}, {"U", "scalar"}, + {"a", "scalar", NeuronVarAccess::READ_ONLY}, {"b", "scalar", NeuronVarAccess::READ_ONLY}, + {"c", "scalar", NeuronVarAccess::READ_ONLY}, {"d", "scalar", NeuronVarAccess::READ_ONLY}}); }; //---------------------------------------------------------------------------- @@ -247,7 +247,7 @@ class LIF : public Base {"ExpTC", [](const std::unordered_map &pars, double dt){ return std::exp(-dt / pars.at("TauM")); }}, {"Rmembrane", [](const std::unordered_map &pars, double){ return pars.at("TauM") / pars.at("C"); }}}); - SET_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); + SET_NEURON_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); SET_NEEDS_AUTO_REFRACTORY(false); }; @@ -291,7 +291,7 @@ class SpikeSourceArray : public Base "$(startSpike) != $(endSpike) && " "$(t) >= $(spikeTimes)[$(startSpike)]" ); SET_RESET_CODE( "$(startSpike)++;\n" ); - SET_VARS( {{"startSpike", "unsigned int"}, {"endSpike", "unsigned int", NeuronVarAccess::READ_ONLY_DUPLICATE}} ); + SET_NEURON_VARS({{"startSpike", "unsigned int"}, {"endSpike", "unsigned int", NeuronVarAccess::READ_ONLY_DUPLICATE}}); SET_EXTRA_GLOBAL_PARAMS( {{"spikeTimes", "scalar*"}} ); SET_NEEDS_AUTO_REFRACTORY(false); }; @@ -351,7 +351,7 @@ class Poisson : public Base SET_THRESHOLD_CONDITION_CODE("$(V) >= $(Vspike)"); SET_PARAM_NAMES({"trefract", "tspike", "Vspike", "Vrest"}); - SET_VARS({{"V", "scalar"}, {"spikeTime", "scalar"}}); + SET_NEURON_VARS({{"V", "scalar"}, {"spikeTime", "scalar"}}); SET_EXTRA_GLOBAL_PARAMS({{"firingProb", "scalar*"}, {"offset", "unsigned int"}}); }; @@ -387,7 +387,7 @@ class PoissonNew : public Base SET_THRESHOLD_CONDITION_CODE("$(timeStepToSpike) <= 0.0"); SET_PARAM_NAMES({"rate"}); - SET_VARS({{"timeStepToSpike", "scalar"}}); + SET_NEURON_VARS({{"timeStepToSpike", "scalar"}}); SET_DERIVED_PARAMS({{"isi", [](const std::unordered_map &pars, double dt){ return 1000.0 / (pars.at("rate") * dt); }}}); SET_NEEDS_AUTO_REFRACTORY(false); }; @@ -485,7 +485,7 @@ class TraubMiles : public Base SET_THRESHOLD_CONDITION_CODE("$(V) >= 0.0"); SET_PARAM_NAMES({"gNa", "ENa", "gK", "EK", "gl", "El", "C"}); - SET_VARS({{"V", "scalar"}, {"m", "scalar"}, {"h", "scalar"}, {"n", "scalar"}}); + SET_NEURON_VARS({{"V", "scalar"}, {"m", "scalar"}, {"h", "scalar"}, {"n", "scalar"}}); }; //---------------------------------------------------------------------------- diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h index fa27989bff..280b031652 100644 --- 
a/include/genn/genn/varAccess.h +++ b/include/genn/genn/varAccess.h @@ -100,15 +100,29 @@ inline bool operator & (VarAccessMode mode, VarAccessModeAttribute modeAttribute return (static_cast(mode) & static_cast(modeAttribute)) != 0; } +inline bool operator & (NeuronVarAccess mode, VarAccessModeAttribute modeAttribute) +{ + return (static_cast(mode) & static_cast(modeAttribute)) != 0; +} + +inline bool operator & (SynapseVarAccess mode, VarAccessModeAttribute modeAttribute) +{ + return (static_cast(mode) & static_cast(modeAttribute)) != 0; +} + +inline bool operator & (CustomUpdateVarAccess mode, VarAccessModeAttribute modeAttribute) +{ + return (static_cast(mode) & static_cast(modeAttribute)) != 0; +} inline bool operator & (VarAccessDim a, VarAccessDim b) { return (static_cast(a) & static_cast(b)) != 0; } -inline VarAccessDim operator | (VarAccessDim a, VarAccessDim b) +/*inline VarAccessDim operator | (VarAccessDim a, VarAccessDim b) { return static_cast(static_cast(a) | static_cast(b)); -} +}*/ inline VarAccessDim clearDim(VarAccessDim a, VarAccessDim b) diff --git a/include/genn/genn/weightUpdateModels.h b/include/genn/genn/weightUpdateModels.h index 038a74b48e..08d8b008cc 100644 --- a/include/genn/genn/weightUpdateModels.h +++ b/include/genn/genn/weightUpdateModels.h @@ -140,7 +140,7 @@ class StaticPulse : public Base public: DECLARE_SNIPPET(StaticPulse); - SET_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}}); + SET_SYNAPSE_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}}); SET_SIM_CODE("addToPost(g);\n"); }; @@ -191,7 +191,7 @@ class StaticPulseDendriticDelay : public Base public: DECLARE_SNIPPET(StaticPulseDendriticDelay); - SET_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}, {"d", "uint8_t", SynapseVarAccess::READ_ONLY}}); + SET_SYNAPSE_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}, {"d", "uint8_t", SynapseVarAccess::READ_ONLY}}); SET_SIM_CODE("addToPostDelay(g, d);\n"); }; @@ -228,7 +228,7 @@ class StaticGraded : public Base DECLARE_SNIPPET(StaticGraded); SET_PARAM_NAMES({"Epre", "Vslope"}); - SET_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}}); + SET_SYNAPSE_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}}); SET_EVENT_CODE("addToPost(fmax(0.0, g * tanh((V_pre - Epre) / Vslope) * DT));\n"); @@ -299,7 +299,7 @@ class PiecewiseSTDP : public Base SET_PARAM_NAMES({"tLrn", "tChng", "tDecay", "tPunish10", "tPunish01", "gMax", "gMid", "gSlope", "tauShift", "gSyn0"}); - SET_VARS({{"g", "scalar"}, {"gRaw", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}, {"gRaw", "scalar"}}); SET_SIM_CODE( "addToPost(g);\n" diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 1fbae498d8..3be1e8db6e 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -706,7 +706,7 @@ std::vector BackendBase::genInitReductionTargets(C [batchSize, &cg](const Models::VarReference &varRef, const std::string &index) { return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, batchSize, - varRef.getDims(), index); + varRef.getVarDims(), index); }); } //----------------------------------------------------------------------- @@ -717,7 +717,7 @@ std::vector BackendBase::genInitReductionTargets(C os, cg, batchSize, idx, [batchSize, &cg](const Models::WUVarReference &varRef, const std::string &index) { - return cg.getVarRefIndex(batchSize, varRef.getDims(), index); + return cg.getVarRefIndex(batchSize, varRef.getVarDims(), index); }); } } // namespace GeNN::CodeGenerator \ No newline at end 
of file diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index d68f4fc38f..1f68b58e53 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -115,7 +115,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back for(const auto &v : getArchetype().getCustomConnectivityUpdateModel()->getPreVarRefs()) { // If model isn't batched or variable isn't duplicated const auto &varRef = getArchetype().getPreVarReferences().at(v.name); - if(batchSize == 1 || !(varRef.getDims() & VarAccessDim::BATCH)) { + if(batchSize == 1 || !(varRef.getVarDims() & VarAccessDim::BATCH)) { // Determine index const std::string index = (varRef.getDelayNeuronGroup() != nullptr) ? "$(_pre_delay_offset) + $(id_pre)" : "$(id_pre)"; diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 1d5041181b..60b55669f8 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -71,7 +71,7 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E [this, batchSize, &varEnv](const std::string&, const Models::VarReference &v) { return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, batchSize, - v.getDims(), "$(id)"); + v.getVarDims(), "$(id)"); }); Transpiler::ErrorHandler errorHandler("Custom update '" + getArchetype().getName() + "' update code"); diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 90bd392553..bd1ae5e25b 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -18,39 +18,70 @@ Base::EGPRef::EGPRef(const std::string &n, const std::string &t) } //---------------------------------------------------------------------------- -// GeNN::Models::Base +// VarReference //---------------------------------------------------------------------------- -void Base::updateHash(boost::uuids::detail::sha1 &hash) const +const std::string &VarReference::getVarName() const { - // Superclass - Snippet::Base::updateHash(hash); - - Utils::updateHash(getVars(), hash); + return std::visit( + Utils::Overload{[](const auto &ref){ return ref.var.name; }}, + m_Detail); } //---------------------------------------------------------------------------- -void Base::validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, - const std::string &description) const +const Type::UnresolvedType &VarReference::getVarType() const { - // Superclass - Snippet::Base::validate(paramValues, description); - - const auto vars = getVars(); - Utils::validateVecNames(vars, "Variable"); - - // Validate variable initialisers - Utils::validateInitialisers(vars, varValues, "variable", description); + return std::visit( + Utils::Overload{[](const auto &ref){ return ref.var.type; }}, + m_Detail); } - //---------------------------------------------------------------------------- -// VarReference +VarAccessDim VarReference::getVarDims() const +{ + return std::visit( + Utils::Overload{ + // If reference is to a custom update variable, + // remove dimensions from those of update + [](const CURef &ref) + { + return getAccessDim(ref.var.access, ref.group->getDims()); + }, + // Otherwise, if reference is to the presynaptic variables of a custom connectivity update, + // remove BATCH dimension as these are never batched + 
[](const CCUPreRef &ref) + { + return clearDim(getAccessDim(ref.var.access), VarAccessDim::BATCH); + }, + // Otherwise, if reference is to the postsynaptic variables of a custom connectivity update, + // remove BATCH dimension as these are never batched + [](const CCUPostRef &ref) + { + return clearDim(getAccessDim(ref.var.access), VarAccessDim::BATCH); + }, + // Otherwise, use dimensionality directly + [](const auto &ref) { return getAccessDim(ref.var.access); }}, + m_Detail); +} +//---------------------------------------------------------------------------- +unsigned int VarReference::getSize() const +{ + return std::visit( + Utils::Overload{ + [](const NGRef &ref) { return ref.group->getNumNeurons(); }, + [](const PSMRef &ref) { return ref.group->getTrgNeuronGroup()->getNumNeurons(); }, + [](const WUPreRef &ref) { return ref.group->getSrcNeuronGroup()->getNumNeurons(); }, + [](const WUPostRef &ref) { return ref.group->getTrgNeuronGroup()->getNumNeurons(); }, + [](const CSRef &ref) { return ref.group->getTrgNeuronGroup()->getNumNeurons(); }, + [](const CURef &ref) { return ref.group->getSize(); }, + [](const CCUPreRef &ref) { return ref.group->getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(); }, + [](const CCUPostRef &ref) { return ref.group->getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons(); }}, + m_Detail); +} //---------------------------------------------------------------------------- NeuronGroup *VarReference::getDelayNeuronGroup() const { return std::visit( Utils::Overload{ [this](const NGRef &ref)->NeuronGroup* { - return (ref.group->isDelayRequired() && ref.group->isVarQueueRequired(getVar().name)) ? ref.group : nullptr; + return (ref.group->isDelayRequired() && ref.group->isVarQueueRequired(ref.var.name)) ? ref.group : nullptr; }, [](const WUPreRef &ref)->NeuronGroup* { return (ref.group->getDelaySteps() > 0) ? 
ref.group->getSrcNeuronGroup() : nullptr; @@ -62,10 +93,10 @@ NeuronGroup *VarReference::getDelayNeuronGroup() const m_Detail); } //---------------------------------------------------------------------------- -std::string VarReference::getTargetName() const +const std::string &VarReference::getTargetName() const { return std::visit( - Utils::Overload{ + Utils::Overload{ [](const PSMRef &ref) { return ref.group->getFusedPSVarSuffix(); }, [](const WUPreRef &ref) { return ref.group->getFusedWUPreVarSuffix(); }, [](const WUPostRef &ref) { return ref.group->getFusedWUPostVarSuffix(); }, @@ -73,37 +104,6 @@ std::string VarReference::getTargetName() const m_Detail); } //---------------------------------------------------------------------------- -VarAccessDim VarReference::getDims() const -{ - const auto &varAccess = getVar().access; - return std::visit( - Utils::Overload{ - // If reference is to a custom update variable, - // remove dimensions from those of update - [&varAccess](const CURef &ref) - { - return clearDim(ref.group->getDims(), - varAccess.getDims()); - }, - // Otherwise, if reference is to the presynaptic variables of a custom connectivity update, - // remove BATCH dimension as these are never batched - [&varAccess](const CCUPreRef&) - { - return clearDim(varAccess.getDims(), - VarAccessDim::BATCH); - }, - // Otherwise, if reference is to the postsynaptic variables of a custom connectivity update, - // remove BATCH dimension as these are never batched - [&varAccess](const CCUPostRef&) - { - return clearDim(varAccess.getDims(), - VarAccessDim::BATCH); - }, - // Otherwise, use dimensionality directly - [&varAccess](const auto&) { return varAccess.getDims(); }}, - m_Detail); -} -//---------------------------------------------------------------------------- CustomUpdate *VarReference::getReferencedCustomUpdate() const { return std::visit( @@ -118,73 +118,115 @@ bool VarReference::operator < (const VarReference &other) const // **NOTE** variable and target names are enough to guarantee uniqueness const std::string targetName = getTargetName(); const std::string otherTargetName = other.getTargetName(); - return (std::tie(getVar().name, targetName) < std::tie(other.getVar().name, otherTargetName)); + + return std::visit( + Utils::Overload{ + [&targetName, &otherTargetName](const auto &detail, const auto &otherDetail) + { + return (std::tie(detail.var.name, targetName) + < std::tie(otherDetail.var.name, otherTargetName)); + }}, + m_Detail, other.m_Detail); } //---------------------------------------------------------------------------- VarReference VarReference::createVarRef(NeuronGroup *ng, const std::string &varName) { - return VarReference(ng->getNeuronModel()->getVarIndex(varName), ng->getNeuronModel()->getVars(), - ng->getNumNeurons(), NGRef{static_cast(ng)}); + const auto *nm = ng->getNeuronModel(); + return VarReference(NGRef{static_cast(ng), + nm->getVars()[nm->getVarIndex(varName)]}); } //---------------------------------------------------------------------------- VarReference VarReference::createVarRef(CurrentSource *cs, const std::string &varName) { - auto *csInternal = static_cast(cs); - return VarReference(cs->getCurrentSourceModel()->getVarIndex(varName), cs->getCurrentSourceModel()->getVars(), - csInternal->getTrgNeuronGroup()->getNumNeurons(), CSRef{csInternal}); + const auto *csm = cs->getCurrentSourceModel(); + return VarReference(CSRef{static_cast(cs), + csm->getVars()[csm->getVarIndex(varName)]}); } 
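[Editorial note, not part of the patch] The getVarDims() implementation above combines getAccessDim() and clearDim(), which rely on the bit layout introduced in patch 01: the access mode sits in the low five bits and the VarAccessDim flags start at bit 5. The self-contained sketch below illustrates that arithmetic with simplified unsigned-int helpers; the mode values are placeholder constants, and only the mode-below-0x20 / dimensions-from-bit-5 split and the 0x1F masks are taken from varAccess.h.

#include <cassert>

// Placeholder mode values (the real VarAccessMode constants live elsewhere);
// what matters here is only that they fit below 0x20
enum class Mode : unsigned int { READ_WRITE = 3, READ_ONLY = 1 };
// Dimension flags, mirroring VarAccessDim from patch 01
enum class Dim : unsigned int { NEURON = (1 << 5), BATCH = (1 << 9) };

// Simplified, unsigned-int stand-ins for the helpers in varAccess.h
constexpr unsigned int getAccessDimBits(unsigned int access) { return access & ~0x1Fu; }
constexpr unsigned int clearDimBits(unsigned int dims, Dim d) { return dims & ~static_cast<unsigned int>(d); }

int main()
{
    // A batched, per-neuron, read-write access type packs mode and dimensions together
    const unsigned int access = static_cast<unsigned int>(Mode::READ_WRITE)
                                | static_cast<unsigned int>(Dim::NEURON)
                                | static_cast<unsigned int>(Dim::BATCH);

    // getVarAccessMode(): keep only the low five bits
    assert(static_cast<Mode>(access & 0x1Fu) == Mode::READ_WRITE);

    // getVarDims() for e.g. a custom connectivity update variable:
    // take the dimension bits, then clear BATCH because these are never batched
    assert(clearDimBits(getAccessDimBits(access), Dim::BATCH) == static_cast<unsigned int>(Dim::NEURON));
    return 0;
}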
//---------------------------------------------------------------------------- VarReference VarReference::createVarRef(CustomUpdate *cu, const std::string &varName) { - return VarReference(cu->getCustomUpdateModel()->getVarIndex(varName), cu->getCustomUpdateModel()->getVars(), - cu->getSize(), CURef{static_cast(cu)}); + const auto *cum = cu->getCustomUpdateModel(); + return VarReference(CURef{static_cast(cu), + cum->getVars()[cum->getVarIndex(varName)]}); } //---------------------------------------------------------------------------- VarReference VarReference::createPreVarRef(CustomConnectivityUpdate *ccu, const std::string &varName) { - auto *ccuInternal = static_cast(ccu); - auto *sg = ccuInternal->getSynapseGroup(); - return VarReference(ccu->getCustomConnectivityUpdateModel()->getPreVarIndex(varName), ccu->getCustomConnectivityUpdateModel()->getPreVars(), - sg->getSrcNeuronGroup()->getNumNeurons(), CCUPreRef{ccuInternal}); + const auto *ccum = ccu->getCustomConnectivityUpdateModel(); + return VarReference(CCUPreRef{static_cast(ccu), + ccum->getPreVars()[ccum->getPreVarIndex(varName)]}); } //---------------------------------------------------------------------------- VarReference VarReference::createPostVarRef(CustomConnectivityUpdate *ccu, const std::string &varName) { - auto *ccuInternal = static_cast(ccu); - auto *sg = ccuInternal->getSynapseGroup(); - return VarReference(ccu->getCustomConnectivityUpdateModel()->getPostVarIndex(varName), ccu->getCustomConnectivityUpdateModel()->getPostVars(), - sg->getTrgNeuronGroup()->getNumNeurons(), CCUPostRef{ccuInternal}); + const auto *ccum = ccu->getCustomConnectivityUpdateModel(); + return VarReference(CCUPostRef{static_cast(ccu), + ccum->getPostVars()[ccum->getPostVarIndex(varName)]}); } //---------------------------------------------------------------------------- VarReference VarReference::createPSMVarRef(SynapseGroup *sg, const std::string &varName) { - auto *sgInternal = static_cast(sg); - return VarReference(sg->getPSModel()->getVarIndex(varName), sg->getPSModel()->getVars(), - sgInternal->getTrgNeuronGroup()->getNumNeurons(), PSMRef{sgInternal}); + const auto *psm = sg->getPSModel(); + return VarReference(PSMRef{static_cast(sg), + psm->getVars()[psm->getVarIndex(varName)]}); } //---------------------------------------------------------------------------- VarReference VarReference::createWUPreVarRef(SynapseGroup *sg, const std::string &varName) { - auto *sgInternal = static_cast(sg); - return VarReference(sg->getWUModel()->getPreVarIndex(varName), sg->getWUModel()->getPreVars(), - sgInternal->getSrcNeuronGroup()->getNumNeurons(), WUPreRef{sgInternal}); + const auto *wum = sg->getWUModel(); + return VarReference(WUPreRef{static_cast(sg), + wum->getPreVars()[wum->getPreVarIndex(varName)]}); } //---------------------------------------------------------------------------- VarReference VarReference::createWUPostVarRef(SynapseGroup *sg, const std::string &varName) { - auto *sgInternal = static_cast(sg); - return VarReference(sg->getWUModel()->getPostVarIndex(varName), sg->getWUModel()->getPostVars(), - sgInternal->getTrgNeuronGroup()->getNumNeurons(), WUPostRef{sgInternal}); + const auto *wum = sg->getWUModel(); + return VarReference(WUPostRef{static_cast(sg), + wum->getPostVars()[wum->getPostVarIndex(varName)]}); } //---------------------------------------------------------------------------- // WUVarReference //---------------------------------------------------------------------------- -std::string WUVarReference::getTargetName() const 
+const std::string &WUVarReference::getVarName() const
+{
+    return std::visit(
+        Utils::Overload{[](const auto &ref){ return ref.var.name; }},
+        m_Detail);
+}
+//----------------------------------------------------------------------------
+const Type::UnresolvedType &WUVarReference::getVarType() const
+{
+    return std::visit(
+        Utils::Overload{[](const auto &ref){ return ref.var.type; }},
+        m_Detail);
+}
+//----------------------------------------------------------------------------
+VarAccessDim WUVarReference::getVarDims() const
 {
     return std::visit(
         Utils::Overload{
-            [](const auto &ref) { return ref.group->getName(); },
+            // If reference is to a custom update variable,
+            // remove dimensions from those of update
+            [](const CURef &ref)
+            {
+                return getAccessDim(ref.var.access, ref.group->getDims());
+            },
+            // Otherwise, if reference is to the synaptic variables of a custom connectivity update,
+            // remove BATCH dimension as these are never batched
+            [](const CCURef &ref)
+            {
+                return clearDim(getAccessDim(ref.var.access), VarAccessDim::BATCH);
+            },
+            // Otherwise, use dimensionality directly
+            [](const WURef &ref){ return getAccessDim(ref.var.access); }},
+        m_Detail);
+}
+//----------------------------------------------------------------------------
+const std::string &WUVarReference::getTargetName() const
+{
+    return std::visit(
+        Utils::Overload{[](const auto &ref) { return ref.group->getName(); }},
+        m_Detail);
+}
 //----------------------------------------------------------------------------
@@ -193,20 +235,73 @@ SynapseGroup *WUVarReference::getSynapseGroup() const
     return getSynapseGroupInternal();
 }
 //------------------------------------------------------------------------
-SynapseGroup *WUVarReference::getTransposeSynapseGroup() const
+std::optional<std::string> WUVarReference::getTransposeVarName() const
 {
-    return getTransposeSynapseGroupInternal();
+    return std::visit(
+        Utils::Overload{
+            [](const WURef &ref)->std::optional<std::string>
+            {
+                if(ref.transposeVar) {
+                    return ref.transposeVar->name;
+                }
+                else {
+                    return std::nullopt;
+                }
+            },
+            [](const auto&){ return std::nullopt; }},
+        m_Detail);
+}
+//------------------------------------------------------------------------
+std::optional<Type::UnresolvedType> WUVarReference::getTransposeVarType() const
+{
+    return std::visit(
+        Utils::Overload{
+            [](const WURef &ref)->std::optional<Type::UnresolvedType>
+            {
+                if(ref.transposeVar) {
+                    return ref.transposeVar->type;
+                }
+                else {
+                    return std::nullopt;
+                }
+            },
+            [](const auto&){ return std::nullopt; }},
+        m_Detail);
 }
 //------------------------------------------------------------------------
-std::string WUVarReference::getTransposeTargetName() const
+std::optional<VarAccessDim> WUVarReference::getTransposeVarDims() const
 {
     return std::visit(
         Utils::Overload{
-            [](const WURef &ref) { return ref.transposeGroup->getName(); },
-            [](const auto&)->std::string { throw std::runtime_error("No transpose"); }},
+            [](const WURef &ref)->std::optional<VarAccessDim>
+            {
+                if(ref.transposeVar) {
+                    return getAccessDim(ref.transposeVar->access);
+                }
+                else {
+                    return std::nullopt;
+                }
+            },
+            [](const auto&){ return std::nullopt; }},
         m_Detail);
 }
 //------------------------------------------------------------------------
+SynapseGroup *WUVarReference::getTransposeSynapseGroup() const
+{
+    return getTransposeSynapseGroupInternal();
+}
+//------------------------------------------------------------------------
+std::optional<std::string> WUVarReference::getTransposeTargetName() const
+{
+    const auto *transposeSG = getTransposeSynapseGroup();
+    if(transposeSG) {
+        return transposeSG->getName();
+    }
+    else {
+
return std::nullopt; + } +} +//------------------------------------------------------------------------ CustomUpdateWU *WUVarReference::getReferencedCustomUpdate() const { return std::visit( @@ -218,49 +313,39 @@ CustomUpdateWU *WUVarReference::getReferencedCustomUpdate() const //------------------------------------------------------------------------ bool WUVarReference::operator < (const WUVarReference &other) const { - // **NOTE** variable and target names are enough to guarantee uniqueness - const bool hasTranspose = (getTransposeSynapseGroup() != nullptr); - const bool otherHasTranspose = (other.getTransposeSynapseGroup() != nullptr); - if (hasTranspose && otherHasTranspose) { - return (std::make_tuple(getVar().name, getTargetName(), getTransposeVar().name, getTransposeTargetName()) - < std::tuple(other.getVar().name, other.getTargetName(), other.getTransposeVar().name, other.getTransposeTargetName())); - } - else if (hasTranspose) { - return false; - } - else if (otherHasTranspose) { - return true; - } - else { - return (std::make_tuple(getVar().name, getTargetName()) - < std::make_tuple(other.getVar().name, other.getTargetName())); - } + return (std::tie(getVarName(), getTargetName(), getTransposeVarName(), getTransposeTargetName()) + < std::tie(other.getVarName(), other.getTargetName(), other.getTransposeVarName(), other.getTransposeTargetName())); } //------------------------------------------------------------------------ WUVarReference WUVarReference::createWUVarReference(SynapseGroup *sg, const std::string &varName, SynapseGroup *transposeSG, const std::string &transposeVarName) { + const auto *wum = sg->getWUModel(); + auto *sgInternal = static_cast(sg); + const auto var = wum->getVars()[wum->getVarIndex(varName)]; if(transposeSG) { - return WUVarReference(sg->getWUModel()->getVarIndex(varName), sg->getWUModel()->getVars(), - transposeSG->getWUModel()->getVarIndex(transposeVarName), transposeSG->getWUModel()->getVars(), - WURef{static_cast(sg), static_cast(transposeSG)}); + const auto *transposeWUM = transposeSG->getWUModel(); + return WUVarReference(WURef{sgInternal, static_cast(transposeSG), + var, transposeWUM->getVars()[transposeWUM->getVarIndex(transposeVarName)]}); } else { - return WUVarReference(sg->getWUModel()->getVarIndex(varName), sg->getWUModel()->getVars(), - WURef{static_cast(sg), static_cast(transposeSG)}); + return WUVarReference(WURef{static_cast(sg), nullptr, + var, std::nullopt}); } } //------------------------------------------------------------------------ WUVarReference WUVarReference::createWUVarReference(CustomUpdateWU *cu, const std::string &varName) { - return WUVarReference(cu->getCustomUpdateModel()->getVarIndex(varName), cu->getCustomUpdateModel()->getVars(), - CURef{static_cast(cu)}); + const auto *cum = cu->getCustomUpdateModel(); + return WUVarReference(CURef{static_cast(cu), + cum->getVars()[cum->getVarIndex(varName)]}); } //------------------------------------------------------------------------ WUVarReference WUVarReference::createWUVarReference(CustomConnectivityUpdate *ccu, const std::string &varName) { - return WUVarReference(ccu->getCustomConnectivityUpdateModel()->getVarIndex(varName), ccu->getCustomConnectivityUpdateModel()->getVars(), - CCURef{static_cast(ccu)}); + const auto *ccum = ccu->getCustomConnectivityUpdateModel(); + return WUVarReference(CCURef{static_cast(ccu), + ccum->getVars()[ccum->getVarIndex(varName)]}); } //------------------------------------------------------------------------ SynapseGroupInternal 
*WUVarReference::getSynapseGroupInternal() const @@ -280,48 +365,10 @@ SynapseGroupInternal *WUVarReference::getTransposeSynapseGroupInternal() const [](const auto&)->SynapseGroupInternal* { return nullptr; }}, m_Detail); } + //------------------------------------------------------------------------ -VarAccessDim WUVarReference::getVarDims(const Models::Base::Var &var) const -{ - const auto &varAccess = var.access; - return std::visit( - Utils::Overload{ - // If reference is to a custom update variable, - // remove dimensions from those of update - [&varAccess](const CURef &ref) - { - return clearDim(ref.group->getDims(), - varAccess.getDims()); - }, - // Otherwise, if reference is to the synaptic variables of a custom connectivity update, - // remove BATCH dimension as these are never batched - [&varAccess](const CCURef&) - { - return clearDim(varAccess.getDims(), - VarAccessDim::BATCH); - }, - // Otherwise, use dimensionality directly - [&varAccess](const WURef&){ return varAccess.getDims(); }}, - m_Detail); -} -//------------------------------------------------------------------------ -WUVarReference::WUVarReference(size_t varIndex, const Models::Base::VarVec &varVec, - const DetailType &detail) -: VarReferenceBase(varIndex, varVec), m_TransposeVarIndex(std::nullopt), - m_TransposeVar(std::nullopt), m_Detail(detail) -{ - // Check matrix types - auto *sg = getSynapseGroup(); - if(!(sg->getMatrixType() & SynapseMatrixWeight::INDIVIDUAL) && !(sg->getMatrixType() & SynapseMatrixWeight::KERNEL)) { - throw std::runtime_error("Only INDIVIDUAL or KERNEL weight update variables can be referenced."); - } -} -//------------------------------------------------------------------------ -WUVarReference::WUVarReference(size_t varIndex, const Models::Base::VarVec &varVec, - size_t transposeVarIndex, const Models::Base::VarVec &transposeVarVec, - const DetailType &detail) -: VarReferenceBase(varIndex, varVec), m_TransposeVarIndex(transposeVarIndex), - m_TransposeVar(transposeVarVec.at(transposeVarIndex)), m_Detail(detail) +WUVarReference::WUVarReference(const DetailType &detail) +: m_Detail(detail) { // Check matrix types auto *sg = getSynapseGroupInternal(); @@ -331,33 +378,35 @@ WUVarReference::WUVarReference(size_t varIndex, const Models::Base::VarVec &varV // Check that both tranpose and original group has individual variables auto *transposeSG = getTransposeSynapseGroupInternal(); - if(!(transposeSG->getMatrixType() & SynapseMatrixWeight::INDIVIDUAL) || !(sg->getMatrixType() & SynapseMatrixWeight::INDIVIDUAL)) { - throw std::runtime_error("Transpose updates can only reference INDIVIDUAL weight update variables."); - } + if(transposeSG) { + if(!(transposeSG->getMatrixType() & SynapseMatrixWeight::INDIVIDUAL) || !(sg->getMatrixType() & SynapseMatrixWeight::INDIVIDUAL)) { + throw std::runtime_error("Transpose updates can only reference INDIVIDUAL weight update variables."); + } - // Check that both the tranpose and main synapse groups have dense connectivity - if(!(transposeSG->getMatrixType() & SynapseMatrixConnectivity::DENSE) || !(sg->getMatrixType() & SynapseMatrixConnectivity::DENSE)) { - throw std::runtime_error("Tranpose updates can only be performed on DENSE weight update model variables."); - } + // Check that both the tranpose and main synapse groups have dense connectivity + if(!(transposeSG->getMatrixType() & SynapseMatrixConnectivity::DENSE) || !(sg->getMatrixType() & SynapseMatrixConnectivity::DENSE)) { + throw std::runtime_error("Tranpose updates can only be performed on DENSE weight 
update model variables."); + } - // Check that sizes of transpose and main synapse groups match - if((transposeSG->getSrcNeuronGroup()->getNumNeurons() != sg->getTrgNeuronGroup()->getNumNeurons()) - || (transposeSG->getTrgNeuronGroup()->getNumNeurons() != sg->getSrcNeuronGroup()->getNumNeurons())) - { - throw std::runtime_error("Transpose updates can only be performed on connections between appropriately sized neuron groups."); - } + // Check that sizes of transpose and main synapse groups match + if((transposeSG->getSrcNeuronGroup()->getNumNeurons() != sg->getTrgNeuronGroup()->getNumNeurons()) + || (transposeSG->getTrgNeuronGroup()->getNumNeurons() != sg->getSrcNeuronGroup()->getNumNeurons())) + { + throw std::runtime_error("Transpose updates can only be performed on connections between appropriately sized neuron groups."); + } - // Check types - // **NOTE** this is a bit over-conservative as, at this point, types are not resolved so "scalar" cannot be compared with "float" - if(getVar().type != getTransposeVar().type) { - throw std::runtime_error("Transpose updates can only be performed on variables with the same type"); - } + // Check types + // **NOTE** this is a bit over-conservative as, at this point, types are not resolved so "scalar" cannot be compared with "float" + if(getVarType() != getTransposeVarType()) { + throw std::runtime_error("Transpose updates can only be performed on variables with the same type"); + } - // Check duplicatedness of variables - if((getDims() & VarAccessDim::BATCH) - != (getTransposeDims() & VarAccessDim::BATCH)) - { - throw std::runtime_error("Transpose updates can only be performed on similarly batched variables"); + // Check duplicatedness of variables + if((getVarDims() & VarAccessDim::BATCH) + != (*getTransposeVarDims() & VarAccessDim::BATCH)) + { + throw std::runtime_error("Transpose updates can only be performed on similarly batched variables"); + } } } @@ -403,11 +452,25 @@ EGPReference EGPReference::createWUEGPRef(const SynapseGroup *sg, const std::str //---------------------------------------------------------------------------- // Free functions //---------------------------------------------------------------------------- -void updateHash(const Base::Var &v, boost::uuids::detail::sha1 &hash) +void updateHash(const Base::NeuronVar &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.name, hash); Type::updateHash(v.type, hash); - v.access.updateHash(hash); + Utils::updateHash(v.access, hash); +} +//---------------------------------------------------------------------------- +void updateHash(const Base::SynapseVar &v, boost::uuids::detail::sha1 &hash) +{ + Utils::updateHash(v.name, hash); + Type::updateHash(v.type, hash); + Utils::updateHash(v.access, hash); +} +//---------------------------------------------------------------------------- +void updateHash(const Base::CustomUpdateVar &v, boost::uuids::detail::sha1 &hash) +{ + Utils::updateHash(v.name, hash); + Type::updateHash(v.type, hash); + Utils::updateHash(v.access, hash); } //---------------------------------------------------------------------------- void updateHash(const Base::VarRef &v, boost::uuids::detail::sha1 &hash) @@ -426,17 +489,17 @@ void updateHash(const Base::EGPRef &e, boost::uuids::detail::sha1 &hash) void updateHash(const VarReference &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.getTargetName(), hash); - Utils::updateHash(v.getVarIndex(), hash); + Utils::updateHash(v.getVarName(), hash); } 
//---------------------------------------------------------------------------- void updateHash(const WUVarReference &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.getTargetName(), hash); - Utils::updateHash(v.getVarIndex(), hash); + Utils::updateHash(v.getVarName(), hash); if(v.getTransposeSynapseGroup() != nullptr) { Utils::updateHash(v.getTransposeTargetName(), hash); - Utils::updateHash(v.getTransposeVarIndex(), hash); + Utils::updateHash(v.getTransposeVarName(), hash); } } //---------------------------------------------------------------------------- From eae1e11d2edbf211780e9909d129c7a347261685 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 8 Sep 2023 17:26:00 +0100 Subject: [PATCH 47/60] fixed up code base --- .../backends/single_threaded_cpu/backend.h | 3 +- .../genn/genn/code_generator/backendBase.h | 7 ++- .../customConnectivityUpdateGroupMerged.h | 4 +- .../genn/genn/code_generator/environment.h | 51 ++++++++++--------- include/genn/genn/customUpdate.h | 8 +-- include/genn/genn/models.h | 4 +- include/genn/genn/varAccess.h | 14 +++-- .../backends/single_threaded_cpu/backend.cc | 4 +- .../customConnectivityUpdateGroupMerged.cc | 24 ++++----- .../code_generator/customUpdateGroupMerged.cc | 12 ++--- .../genn/code_generator/generateRunner.cc | 6 +-- .../code_generator/neuronUpdateGroupMerged.cc | 46 ++++++++--------- .../synapseUpdateGroupMerged.cc | 16 +++--- src/genn/genn/customConnectivityUpdate.cc | 10 ++-- src/genn/genn/customUpdate.cc | 4 +- src/genn/genn/customUpdateModels.cc | 8 --- src/genn/genn/models.cc | 24 ++++----- 17 files changed, 125 insertions(+), 120 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 93fd13577f..cea1e191ba 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -254,8 +254,7 @@ class BACKEND_EXPORT Backend : public BackendBase // If variable is a reduction target, copy value from register straight back into global memory if(v.access & VarAccessModeAttribute::REDUCE) { const std::string idx = env.getName(idxName); - const VarAccessDim varAccessDim = clearDim(cg.getArchetype().getDims(), - v.access.template getDims()); + const VarAccessDim varAccessDim = getAccessDim(v.access, cg.getArchetype().getDims()); env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(1, varAccessDim, idx) << "] = " << env[v.name] << ";" << std::endl; } } diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 14edf81e8f..5b62b010fd 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -573,10 +573,9 @@ class GENN_EXPORT BackendBase // If variable is a reduction target, define variable initialised to correct initial value for reduction if (v.access & VarAccessModeAttribute::REDUCE) { const auto resolvedType = v.type.resolve(cg.getTypeContext()); - os << resolvedType.getName() << " _lr" << v.name << " = " << getReductionInitialValue(v.access, resolvedType) << ";" << std::endl; - const VarAccessDim varAccessDim = clearDim(cg.getArchetype().getDims(), - v.access.template getDims()); - reductionTargets.push_back({v.name, resolvedType, v.access, + os << resolvedType.getName() << " _lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(v.access), resolvedType) << ";" << std::endl; + const VarAccessDim varAccessDim = getAccessDim(v.access, 
cg.getArchetype().getDims()); + reductionTargets.push_back({v.name, resolvedType, getVarAccessMode(v.access), cg.getVarIndex(batchSize, varAccessDim, idx)}); } } diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index 2f2d96080b..8d815968c1 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -80,7 +80,7 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public GroupMerged &(GroupInternal::*)(void) const; - using GetVarIndexFn = std::function; + + template + using AdapterDef = typename std::invoke_result_t::value_type; + + template + using GetVarIndexFn = std::function::AccessType, const std::string&)>; template using GetVarRefIndexFn = std::function; @@ -631,7 +636,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase - void addVars(const std::string &arrayPrefix, GetVarIndexFn getIndexFn, + void addVars(const std::string &arrayPrefix, GetVarIndexFn getIndexFn, const std::string &fieldSuffix = "", bool readOnly = false) { // Loop through variables @@ -654,7 +659,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase(arrayPrefix, [&indexSuffix](VarAccess, const std::string &) { return indexSuffix; }, + addVars(arrayPrefix, [&indexSuffix](typename AdapterDef::AccessType, const std::string &) { return indexSuffix; }, fieldSuffix, readOnly); } @@ -695,9 +700,6 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase>> m_Environment; }; -//! Type of a single definition -template -using Def = typename std::invoke_result_t::value_type; //------------------------------------------------------------------------ // GeNN::CodeGenerator::VarCachePolicy @@ -708,8 +710,10 @@ class VarCachePolicy { public: using GroupInternal = typename G::GroupInternal; - using GetIndexFn = std::function; - using ShouldAlwaysCopyFn = std::function; + using AdapterDef = typename std::invoke_result_t::value_type; + using AdapterAccess = typename AdapterDef::AccessType; + using GetIndexFn = std::function; + using ShouldAlwaysCopyFn = std::function; VarCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex, ShouldAlwaysCopyFn shouldAlwaysCopy = ShouldAlwaysCopyFn()) @@ -725,7 +729,7 @@ class VarCachePolicy //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - bool shouldAlwaysCopy(G&, const Def &var) const + bool shouldAlwaysCopy(G&, const AdapterDef &var) const { if(m_ShouldAlwaysCopy) { return m_ShouldAlwaysCopy(var.name, var.access); @@ -735,17 +739,17 @@ class VarCachePolicy } } - std::string getReadIndex(G&, const Def &var) const + std::string getReadIndex(G&, const AdapterDef &var) const { return m_GetReadIndex(var.name, var.access); } - std::string getWriteIndex(G&, const Def &var) const + std::string getWriteIndex(G&, const AdapterDef &var) const { return m_GetWriteIndex(var.name, var.access); } - std::string getTargetName(const GroupInternal &g, const Def &var) const + std::string getTargetName(const GroupInternal &g, const AdapterDef &var) const { return var.name + A(g).getNameSuffix(); } @@ -768,6 +772,7 @@ class VarRefCachePolicy { protected: using GroupInternal = typename G::GroupInternal; + using AdapterDef = typename std::invoke_result_t::value_type; using GetIndexFn = std::function; 
VarRefCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex) @@ -781,27 +786,27 @@ class VarRefCachePolicy //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - bool shouldAlwaysCopy(G&, const Def&) const + bool shouldAlwaysCopy(G&, const AdapterDef&) const { // **NOTE** something else is managing the actual variables // and is therefore responsible for copying between delay slots etc return false; } - std::string getReadIndex(G &g, const Def &var) const + std::string getReadIndex(G &g, const AdapterDef &var) const { return m_GetReadIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name)); } - std::string getWriteIndex(G &g, const Def &var) const + std::string getWriteIndex(G &g, const AdapterDef &var) const { return m_GetWriteIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name)); } - std::string getTargetName(const GroupInternal &g, const Def &var) const + std::string getTargetName(const GroupInternal &g, const AdapterDef &var) const { const auto &initialiser = A(g).getInitialisers().at(var.name); - return initialiser.getVar().name + initialiser.getTargetName(); + return initialiser.getVarName() + initialiser.getTargetName(); } private: @@ -819,9 +824,9 @@ class VarRefCachePolicy template class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P { - - public: + using AdapterDef = typename std::invoke_result_t::value_type; + template EnvironmentLocalCacheBase(G &group, F &fieldGroup, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix, @@ -844,7 +849,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P // Copy definitions of variables which have been referenced into new vector or all if always copy set const auto varDefs = archetypeAdapter.getDefs(); - std::vector> referencedDefs; + std::vector referencedDefs; std::copy_if(varDefs.cbegin(), varDefs.cend(), std::back_inserter(referencedDefs), [this](const auto &v) { @@ -865,7 +870,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P return arrayPrefix + this->getTargetName(group.getGroups().at(i), v); }); - if(v.access == VarAccessMode::READ_ONLY) { + if(getVarAccessMode(v.access) == VarAccessMode::READ_ONLY) { getContextStream() << "const "; } getContextStream() << resolvedType.getName() << " _" << m_LocalPrefix << v.name; @@ -885,7 +890,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P // Loop through referenced definitions again for(const auto &v : referencedDefs) { // If we should always copy variable or variable is read-write - if(this->shouldAlwaysCopy(m_Group.get(), v) || (v.access == VarAccessMode::READ_WRITE)) { + if(this->shouldAlwaysCopy(m_Group.get(), v) || (getVarAccessMode(v.access) == VarAccessMode::READ_WRITE)) { getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << printSubs(this->getWriteIndex(m_Group.get(), v), *this) << "]"; getContextStream() << " = _" << m_LocalPrefix << v.name << ";" << std::endl; } @@ -951,7 +956,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P std::string m_ArrayPrefix; std::string m_FieldSuffix; std::string m_LocalPrefix; - std::unordered_map> m_VariablesReferenced; + std::unordered_map> m_VariablesReferenced; }; template diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h 
index 30fa9df2bf..7eace9d364 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -96,7 +96,7 @@ class GENN_EXPORT CustomUpdateBase [reduceDim](const Models::Base::CustomUpdateVar &v) { return ((v.access & VarAccessModeAttribute::REDUCE) - && (v.access & reduceDim)); + && (static_cast(v.access) & static_cast(reduceDim))); })) { return true; @@ -124,7 +124,7 @@ class GENN_EXPORT CustomUpdateBase // Loop through variable references and or together their dimensions to get dimensionality of update m_Dims = VarAccessDim{0}; for(const auto &v : varRefs) { - m_Dims = m_Dims | v.second.getDims(); + m_Dims = m_Dims | v.second.getVarDims(); } // Loop through all variable references @@ -133,7 +133,7 @@ class GENN_EXPORT CustomUpdateBase // If the shape of the references variable doesn't match the dimensionality // of the custom update, check its access mode isn't read-write - if((m_Dims != varRef.getDims()) + if((m_Dims != varRef.getVarDims()) && (modelVarRef.access == VarAccessMode::READ_WRITE)) { throw std::runtime_error("Variable references to lower-dimensional variables cannot be read-write."); @@ -210,7 +210,7 @@ class CustomUpdateVarAdapter VarAccessDim getVarDims(const Models::Base::CustomUpdateVar &var) const { - getAccessDim(var.access, m_CU.getDims()); + return getAccessDim(var.access, m_CU.getDims()); } private: diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index f764a06f01..4ab2b4711b 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -63,6 +63,8 @@ class GENN_EXPORT Base : public Snippet::Base VarBase(const std::string &n, const std::string &t, A a) : name(n), type(t), access(a) {} + + using AccessType = A; bool operator == (const VarBase &other) const { @@ -393,7 +395,7 @@ void checkVarReferenceTypes(const std::unordered_map &varRefs, c // Check types of variable references against those specified in model // **THINK** this is rather conservative but I think not allowing "scalar" and whatever happens to be scalar type is ok - if(varRef.getVar().type != modelVarRef.type) { + if(varRef.getVarType() != modelVarRef.type) { throw std::runtime_error("Incompatible type for variable reference '" + modelVarRef.name + "'"); } } diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h index 280b031652..28faaed309 100644 --- a/include/genn/genn/varAccess.h +++ b/include/genn/genn/varAccess.h @@ -114,17 +114,20 @@ inline bool operator & (CustomUpdateVarAccess mode, VarAccessModeAttribute modeA { return (static_cast(mode) & static_cast(modeAttribute)) != 0; } + inline bool operator & (VarAccessDim a, VarAccessDim b) { return (static_cast(a) & static_cast(b)) != 0; } -/*inline VarAccessDim operator | (VarAccessDim a, VarAccessDim b) +inline VarAccessDim operator | (VarAccessDim a, VarAccessDim b) { return static_cast(static_cast(a) | static_cast(b)); -}*/ - +} +//---------------------------------------------------------------------------- +// Free functions +//---------------------------------------------------------------------------- inline VarAccessDim clearDim(VarAccessDim a, VarAccessDim b) { return static_cast(static_cast(a) & ~static_cast(b)); @@ -145,6 +148,11 @@ inline VarAccessDim getAccessDim(CustomUpdateVarAccess v, VarAccessDim popDims) return clearDim(popDims, static_cast(static_cast(v) & ~0x1F)); } +inline VarAccessMode getVarAccessMode(VarAccessMode v) +{ + return v; +} + inline VarAccessMode getVarAccessMode(NeuronVarAccess v) { return static_cast(static_cast(v) & 0x1F); diff --git 
a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 826327071b..5a225c89b8 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -2018,7 +2018,7 @@ void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateG [&cg](const Models::VarReference &varRef, const std::string &index) { return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, 1, - varRef.getDims(), index); + varRef.getVarDims(), index); }); } //-------------------------------------------------------------------------- @@ -2028,7 +2028,7 @@ void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateW env, cg, idxName, [&cg](const Models::WUVarReference &varRef, const std::string &index) { - return cg.getVarRefIndex(1, varRef.getDims(), index); + return cg.getVarRefIndex(1, varRef.getVarDims(), index); }); } } // namespace GeNN::CodeGenerator::SingleThreadedCPU diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 1f68b58e53..b1d92d1723 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -36,12 +36,12 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t dependentVarsList.sort([](const auto &a, const auto &b) { boost::uuids::detail::sha1 hashA; - Type::updateHash(a.getVar().type, hashA); - Utils::updateHash(a.getDims(), hashA); + Type::updateHash(a.getVarType(), hashA); + Utils::updateHash(a.getVarDims(), hashA); boost::uuids::detail::sha1 hashB; - Type::updateHash(b.getVar().type, hashB); - Utils::updateHash(b.getDims(), hashB); + Type::updateHash(b.getVarType(), hashB); + Utils::updateHash(b.getVarDims(), hashB); return (hashA.get_digest() < hashB.get_digest()); }); @@ -129,7 +129,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back [&backend, v](const auto &g, size_t) { const auto varRef = g.getPreVarReferences().at(v.name); - return backend.getDeviceVarPrefix() + varRef.getVar().name + varRef.getTargetName(); + return backend.getDeviceVarPrefix() + varRef.getVarName() + varRef.getTargetName(); }, index); } @@ -144,12 +144,12 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Add private fields for dependent variables for(size_t i = 0; i < getSortedArchetypeDependentVars().size(); i++) { - auto resolvedType = getSortedArchetypeDependentVars().at(i).getVar().type.resolve(getTypeContext()); + auto resolvedType = getSortedArchetypeDependentVars().at(i).getVarType().resolve(getTypeContext()); updateEnv.addField(resolvedType.createPointer(), "_dependent_var_" + std::to_string(i), "dependentVar" + std::to_string(i), [i, &backend, this](const auto&, size_t g) { const auto &varRef = m_SortedDependentVars[g][i]; - return backend.getDeviceVarPrefix() + varRef.getVar().name + varRef.getTargetName(); + return backend.getDeviceVarPrefix() + varRef.getVarName() + varRef.getTargetName(); }); } @@ -186,7 +186,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back for (size_t i = 0; i < ccuVarRefs.size(); i++) { // If model is batched and this variable is duplicated const auto &varRef = getArchetype().getVarReferences().at(ccuVarRefs[i].name); - if (batchSize > 1 && (varRef.getDims() & VarAccessDim::BATCH)) { + if (batchSize > 1 && 
(varRef.getVarDims() & VarAccessDim::BATCH)) { // Copy parameter into a register (just incase it's e.g. a RNG call) and copy into all batches addSynapse << "const " << ccuVarRefs[i].type.resolve(getTypeContext()).getName() << " _" << ccuVarRefs[i].name << "Val = $(" << (1 + ccuVars.size() + i) << ");" << std::endl; addSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; @@ -206,7 +206,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Loop through any other dependent variables for (size_t i = 0; i < dependentVars.size(); i++) { // If model is batched and this dependent variable is duplicated - if (batchSize > 1 && (dependentVars.at(i).getDims() & VarAccessDim::BATCH)) { + if (batchSize > 1 && (dependentVars.at(i).getVarDims() & VarAccessDim::BATCH)) { // Loop through all batches and zero addSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; { @@ -219,7 +219,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back addSynapse << "$(_dependent_var_" << i << ")[newIdx] = 0;" << std::endl; } - addSynapseTypes.push_back(dependentVars.at(i).getVar().type.resolve(getTypeContext())); + addSynapseTypes.push_back(dependentVars.at(i).getVarType().resolve(getTypeContext())); } // Increment row length @@ -251,7 +251,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back for (size_t i = 0; i < ccuVarRefs.size(); i++) { // If model is batched and this variable is duplicated const auto &varRef = getArchetype().getVarReferences().at(ccuVarRefs[i].name); - if (batchSize > 1 && (varRef.getDims() & VarAccessDim::BATCH)) { + if (batchSize > 1 && (varRef.getVarDims() & VarAccessDim::BATCH)) { // Loop through all batches and copy custom connectivity update variable references from end of row over synapse to be deleted removeSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; { @@ -269,7 +269,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Loop through any other dependent variables for (size_t i = 0; i < dependentVars.size(); i++) { // If model is batched and this dependent variable is duplicated - if (batchSize > 1 && (dependentVars.at(i).getDims() & VarAccessDim::BATCH)) { + if (batchSize > 1 && (dependentVars.at(i).getVarDims() & VarAccessDim::BATCH)) { // Loop through all batches and copy dependent variable from end of row over synapse to be deleted removeSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; { diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 60b55669f8..2aa9e70973 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -60,9 +60,9 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", - [this, batchSize, &cuEnv](const std::string&, VarAccess d) + [this, batchSize, &cuEnv](const std::string&, CustomUpdateVarAccess d) { - return getVarIndex(batchSize, clearDim(getArchetype().getDims(), d.getDims()), "$(id)"); + return getVarIndex(batchSize, getAccessDim(d, getArchetype().getDims()), "$(id)"); }); // Create an environment which caches variable references in local variables if they are accessed @@ -187,9 +187,9 @@ void 
CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", - [this, batchSize, &cuEnv](const std::string&, VarAccess d) + [this, batchSize, &cuEnv](const std::string&, CustomUpdateVarAccess d) { - return getVarIndex(batchSize, clearDim(getArchetype().getDims(), d.getDims()), "$(id_syn)"); + return getVarIndex(batchSize, getAccessDim(d, getArchetype().getDims()), "$(id_syn)"); }); // Create an environment which caches variable references in local variables if they are accessed @@ -197,7 +197,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", [this, batchSize, &varEnv](const std::string&, const Models::WUVarReference &v) { - return getVarRefIndex(batchSize, v.getDims(), "$(id_syn)"); + return getVarRefIndex(batchSize, v.getVarDims(), "$(id_syn)"); }); Transpiler::ErrorHandler errorHandler("Custom update '" + getArchetype().getName() + "' update code"); @@ -242,7 +242,7 @@ std::string CustomUpdateTransposeWUGroupMerged::addTransposeField(const BackendB [&backend, v](const auto &g, size_t) { const auto varRef = g.getVarReferences().at(v.name); - return backend.getDeviceVarPrefix() + varRef.getTransposeVar().name + varRef.getTransposeTargetName(); + return backend.getDeviceVarPrefix() + *varRef.getTransposeVarName() + *varRef.getTransposeTargetName(); }); // Return name of transpose variable diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 4100aae14d..4aab494088 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -1110,8 +1110,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, std::vector neuronStatePushPullFunctions; for(const auto &var : neuronModel->getVars()) { const auto &varInit = n.second.getVarInitialisers().at(var.name); - const unsigned int numCopies = getNumVarCopies(var.access.getDims(), batchSize); - const unsigned int numElements = getNumNeuronVarElements(var.access.getDims(), n.second.getNumNeurons()); + const unsigned int numCopies = getNumVarCopies(getAccessDim(var.access), batchSize); + const unsigned int numElements = getNumNeuronVarElements(getAccessDim(var.access), n.second.getNumNeurons()); const size_t count = n.second.isVarQueueRequired(var.name) ? 
numCopies * numElements * n.second.getNumDelaySlots() : numCopies * numElements; const bool autoInitialized = !Utils::areTokensEmpty(varInit.getCodeTokens()); const auto resolvedType = var.type.resolve(modelMerged.getModel().getTypeContext()); @@ -1429,7 +1429,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const bool autoInitialized = !Utils::areTokensEmpty(varInit.getCodeTokens()); const auto resolvedType = wuVar.type.resolve(modelMerged.getModel().getTypeContext()); if(individualWeights || kernelWeights) { - const size_t size = getSynapseVarSize(wuVar.access.getDims(), + const size_t size = getSynapseVarSize(getAccessDim(wuVar.access), backend, s.second, batchSize); genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index a1a60f0692..67191e3326 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -40,9 +40,9 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, &ng](const std::string&, VarAccess d) + [batchSize, &ng](const std::string&, NeuronVarAccess d) { - return ng.getVarIndex(batchSize, d.getDims(), "$(id)"); + return ng.getVarIndex(batchSize, getAccessDim(d), "$(id)"); }); // Pretty print code back to environment @@ -121,9 +121,9 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, &ng](const std::string&, VarAccess d) + [batchSize, &ng](const std::string&, NeuronVarAccess d) { - return ng.getVarIndex(batchSize, d.getDims(), "$(id)"); + return ng.getVarIndex(batchSize, getAccessDim(d), "$(id)"); }); // Pretty print code back to environment @@ -202,15 +202,15 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccess d) + [batchSize, delayed, &synEnv, &ng](const std::string&, NeuronVarAccess d) { - return ng.getReadVarIndex(delayed, batchSize, d.getDims(), "$(id)"); + return ng.getReadVarIndex(delayed, batchSize, getAccessDim(d), "$(id)"); }, - [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccess d) + [batchSize, delayed, &synEnv, &ng](const std::string&, NeuronVarAccess d) { - return ng.getWriteVarIndex(delayed, batchSize, d.getDims(), "$(id)"); + return ng.getWriteVarIndex(delayed, batchSize, getAccessDim(d), "$(id)"); }, - [delayed](const std::string&, VarAccess) + [delayed](const std::string&, NeuronVarAccess) { return delayed; }); @@ -242,8 +242,8 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentEx // Loop through variables and copy between read and write delay slots // **YUCK** this a bit sketchy as fields may not have been added - could add fields 
here but need to guarantee uniqueness for(const auto &v : getArchetype().getWUModel()->getPostVars()) { - if(v.access == VarAccessMode::READ_WRITE) { - const VarAccessDim varDims = v.access.getDims(); + if(getVarAccessMode(v.access) == VarAccessMode::READ_WRITE) { + const VarAccessDim varDims = getAccessDim(v.access); env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, varDims, "$(id)") + "] = "); env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, varDims, "$(id)") + "];"); } @@ -294,15 +294,15 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back const bool delayed = (getArchetype().getDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, delayed, &ng](const std::string&, VarAccess d) + [batchSize, delayed, &ng](const std::string&, NeuronVarAccess d) { - return ng.getReadVarIndex(delayed, batchSize, d.getDims(), "$(id)"); + return ng.getReadVarIndex(delayed, batchSize, getAccessDim(d), "$(id)"); }, - [batchSize, delayed, &ng](const std::string&, VarAccess d) + [batchSize, delayed, &ng](const std::string&, NeuronVarAccess d) { - return ng.getWriteVarIndex(delayed, batchSize, d.getDims(), "$(id)"); + return ng.getWriteVarIndex(delayed, batchSize, getAccessDim(d), "$(id)"); }, - [delayed](const std::string&, VarAccess) + [delayed](const std::string&, NeuronVarAccess) { return delayed; }); @@ -334,8 +334,8 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(EnvironmentEx // Loop through variables and copy between read and write delay slots // **YUCK** this a bit sketchy as fields may not have been added - could add fields here but need to guarantee uniqueness for(const auto &v : getArchetype().getWUModel()->getPreVars()) { - if(v.access == VarAccessMode::READ_WRITE) { - const VarAccessDim varDims = v.access.getDims(); + if(getVarAccessMode(v.access) == VarAccessMode::READ_WRITE) { + const VarAccessDim varDims = getAccessDim(v.access); env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, varDims, "$(id)") + "] = "); env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, varDims, "$(id)") + "];"); } @@ -512,17 +512,17 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // **NOTE** always copy variables if variable is delayed EnvironmentLocalVarCache neuronVarEnv( *this, *this, getTypeContext(), neuronEnv, backend.getDeviceVarPrefix(), "", "l", - [batchSize, &neuronEnv, this](const std::string &varName, VarAccess d) + [batchSize, &neuronEnv, this](const std::string &varName, NeuronVarAccess d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); - return getReadVarIndex(delayed, batchSize, d.getDims(), "$(id)") ; + return getReadVarIndex(delayed, batchSize, getAccessDim(d), "$(id)") ; }, - [batchSize, &neuronEnv, this](const std::string &varName, VarAccess d) + [batchSize, &neuronEnv, this](const std::string &varName, NeuronVarAccess d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); - return getWriteVarIndex(delayed, batchSize, d.getDims(), "$(id)") ; + return getWriteVarIndex(delayed, batchSize, getAccessDim(d), "$(id)") ; }, - [this](const std::string &varName, VarAccess) + [this](const std::string &varName, NeuronVarAccess) { return (getArchetype().isVarQueueRequired(varName) && 
getArchetype().isDelayRequired()); }); diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 0772e46ddf..34fb123854 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -30,15 +30,15 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa // Substitute names of pre and postsynaptic weight update variable synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](VarAccess a, const std::string&) + [&sg, batchSize](NeuronVarAccess a, const std::string&) { - return sg.getPreWUVarIndex(batchSize, a.getDims(), "$(id_pre)"); + return sg.getPreWUVarIndex(batchSize, getAccessDim(a), "$(id_pre)"); }, "", true); synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](VarAccess a, const std::string&) + [&sg, batchSize](NeuronVarAccess a, const std::string&) { - return sg.getPostWUVarIndex(batchSize, a.getDims(), "$(id_post)"); + return sg.getPostWUVarIndex(batchSize, getAccessDim(a), "$(id_post)"); }, "", true); @@ -78,9 +78,9 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa if (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::INDIVIDUAL) { synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](VarAccess a, const std::string&) + [&sg, batchSize](SynapseVarAccess a, const std::string&) { - return sg.getSynVarIndex(batchSize, a.getDims(), "$(id_syn)"); + return sg.getSynVarIndex(batchSize, getAccessDim(a), "$(id_syn)"); }); } // Otherwise, if weights are procedual @@ -121,9 +121,9 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](VarAccess a, const std::string&) + [&sg, batchSize](SynapseVarAccess a, const std::string&) { - return sg.getKernelVarIndex(batchSize, a.getDims(), "$(id_kernel)"); + return sg.getKernelVarIndex(batchSize, getAccessDim(a), "$(id_kernel)"); }); } diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index fdee52420c..37d62991fe 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -180,7 +180,7 @@ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) if (std::any_of(getPreVarReferences().cbegin(), getPreVarReferences().cend(), [](const auto &v) { - return (v.second.getDims() & VarAccessDim::BATCH); + return (v.second.getVarDims() & VarAccessDim::BATCH); })) { throw std::runtime_error("Presynaptic variables referenced by CustomConnectivityUpdate must be SHARED across batches"); @@ -190,7 +190,7 @@ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) if (std::any_of(getPostVarReferences().cbegin(), getPostVarReferences().cend(), [](const auto &v) { - return (v.second.getDims() & VarAccessDim::BATCH); + return (v.second.getVarDims() & VarAccessDim::BATCH); })) { throw std::runtime_error("Postsynaptic variables referenced by CustomConnectivityUpdate must be SHARED across batches"); @@ -312,8 +312,8 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdate::getHashDigest( [](const Models::WUVarReference &v) { boost::uuids::detail::sha1 hash; - Type::updateHash(v.getVar().type, hash); - Utils::updateHash(v.getDims(), hash); + Type::updateHash(v.getVarType(), hash); + Utils::updateHash(v.getVarDims(), hash); return hash.get_digest(); }); @@ -329,7 
+329,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdate::getHashDigest( // Update hash with duplication mode of synaptic variable references for(const auto &v : getVarReferences()) { - Utils::updateHash(v.second.getDims(), hash); + Utils::updateHash(v.second.getVarDims(), hash); } return hash.get_digest(); diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index 41ee1456b7..dee79c24b3 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -192,7 +192,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdate::getHashDigest() const Utils::updateHash((v.second.getDelayNeuronGroup() == nullptr), hash); // Update hash with target variable dimensions as this effects indexing code - Utils::updateHash(v.second.getDims(), hash); + Utils::updateHash(v.second.getVarDims(), hash); } return hash.get_digest(); } @@ -299,7 +299,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getHashDigest() const Utils::updateHash((v.second.getTransposeSynapseGroup() == nullptr), hash); // Update hash with dimensionality of target variable dimensions as this effects indexing code - Utils::updateHash(v.second.getDims(), hash); + Utils::updateHash(v.second.getVarDims(), hash); } return hash.get_digest(); diff --git a/src/genn/genn/customUpdateModels.cc b/src/genn/genn/customUpdateModels.cc index 984c96e9ff..6d26b18ba9 100644 --- a/src/genn/genn/customUpdateModels.cc +++ b/src/genn/genn/customUpdateModels.cc @@ -69,13 +69,5 @@ void Base::validate(const std::unordered_map ¶mValues, // Validate variable reference initialisers Utils::validateInitialisers(varRefs, varRefTargets, "Variable reference", description); Utils::validateVecNames(getExtraGlobalParamRefs(), "Extra global parameter reference"); - - // If any variables have an invalid access mode, give an error - const auto vars = getVars(); - if(std::any_of(vars.cbegin(), vars.cend(), - [](const Models::Base::Var &v){ return !v.access.template isValid(); })) - { - throw std::runtime_error("Custom update model variables must have CustomUpdateVarAccess access type"); - } } } // namespace GeNN::CustomUpdateModels \ No newline at end of file diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index bd1ae5e25b..84f6f46610 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -23,14 +23,14 @@ Base::EGPRef::EGPRef(const std::string &n, const std::string &t) const std::string &VarReference::getVarName() const { return std::visit( - Utils::Overload{[](const auto &ref){ return ref.var.name; }}, + Utils::Overload{[](const auto &ref){ return std::cref(ref.var.name); }}, m_Detail); } //---------------------------------------------------------------------------- const Type::UnresolvedType &VarReference::getVarType() const { return std::visit( - Utils::Overload{[](const auto &ref){ return ref.var.type; }}, + Utils::Overload{[](const auto &ref){ return std::cref(ref.var.type); }}, m_Detail); } //---------------------------------------------------------------------------- @@ -97,10 +97,10 @@ const std::string &VarReference::getTargetName() const { return std::visit( Utils::Overload{ - [](const PSMRef &ref) { return ref.group->getFusedPSVarSuffix(); }, - [](const WUPreRef &ref) { return ref.group->getFusedWUPreVarSuffix(); }, - [](const WUPostRef &ref) { return ref.group->getFusedWUPostVarSuffix(); }, - [](const auto &ref) { return ref.group->getName(); }}, + [](const PSMRef &ref) { return std::cref(ref.group->getFusedPSVarSuffix()); }, + [](const WUPreRef &ref) { return 
std::cref(ref.group->getFusedWUPreVarSuffix()); }, + [](const WUPostRef &ref) { return std::cref(ref.group->getFusedWUPostVarSuffix()); }, + [](const auto &ref) { return std::cref(ref.group->getName()); }}, m_Detail); } //---------------------------------------------------------------------------- @@ -191,14 +191,14 @@ VarReference VarReference::createWUPostVarRef(SynapseGroup *sg, const std::strin const std::string &WUVarReference::getVarName() const { return std::visit( - Utils::Overload{[](const auto &ref){ return ref.var.name; }}, + Utils::Overload{[](const auto &ref){ return std::cref(ref.var.name); }}, m_Detail); } //---------------------------------------------------------------------------- const Type::UnresolvedType &WUVarReference::getVarType() const { return std::visit( - Utils::Overload{[](const auto &ref){ return ref.var.type; }}, + Utils::Overload{[](const auto &ref){ return std::cref(ref.var.type); }}, m_Detail); } //---------------------------------------------------------------------------- @@ -226,7 +226,7 @@ VarAccessDim WUVarReference::getVarDims() const const std::string &WUVarReference::getTargetName() const { return std::visit( - Utils::Overload{[](const auto &ref) { return ref.group->getName(); }}, + Utils::Overload{[](const auto &ref) { return std::cref(ref.group->getName()); }}, m_Detail); } //---------------------------------------------------------------------------- @@ -248,7 +248,7 @@ std::optional WUVarReference::getTransposeVarName() const return std::nullopt; } }, - [](const auto&){ return std::nullopt; }}, + [](const auto&)->std::optional{ return std::nullopt; }}, m_Detail); } //------------------------------------------------------------------------ @@ -265,7 +265,7 @@ std::optional WUVarReference::getTransposeVarType() const return std::nullopt; } }, - [](const auto&){ return std::nullopt; }}, + [](const auto&)->std::optional{ return std::nullopt; }}, m_Detail); } //------------------------------------------------------------------------ @@ -282,7 +282,7 @@ std::optional WUVarReference::getTransposeVarDims() const return std::nullopt; } }, - [](const auto&){ return std::nullopt; }}, + [](const auto&)->std::optional{ return std::nullopt; }}, m_Detail); } //------------------------------------------------------------------------ From 818b283f8abe18e12215ec0866caab4969f943f1 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 8 Sep 2023 17:28:53 +0100 Subject: [PATCH 48/60] rename getAcccessDim to getVarAcccessDim for consistency --- .../backends/single_threaded_cpu/backend.h | 2 +- .../genn/genn/code_generator/backendBase.h | 2 +- include/genn/genn/currentSourceInternal.h | 2 +- .../genn/customConnectivityUpdateInternal.h | 6 +++--- include/genn/genn/customUpdate.h | 2 +- include/genn/genn/neuronGroupInternal.h | 2 +- include/genn/genn/synapseGroupInternal.h | 8 ++++---- include/genn/genn/varAccess.h | 10 +++++----- .../code_generator/customUpdateGroupMerged.cc | 4 ++-- .../genn/code_generator/generateRunner.cc | 6 +++--- .../code_generator/neuronUpdateGroupMerged.cc | 20 +++++++++---------- .../synapseUpdateGroupMerged.cc | 8 ++++---- src/genn/genn/models.cc | 16 +++++++-------- 13 files changed, 44 insertions(+), 44 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index cea1e191ba..47523d47f3 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -254,7 +254,7 @@ class BACKEND_EXPORT Backend : 
public BackendBase // If variable is a reduction target, copy value from register straight back into global memory if(v.access & VarAccessModeAttribute::REDUCE) { const std::string idx = env.getName(idxName); - const VarAccessDim varAccessDim = getAccessDim(v.access, cg.getArchetype().getDims()); + const VarAccessDim varAccessDim = getVarAccessDim(v.access, cg.getArchetype().getDims()); env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(1, varAccessDim, idx) << "] = " << env[v.name] << ";" << std::endl; } } diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 5b62b010fd..13304e37f8 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -574,7 +574,7 @@ class GENN_EXPORT BackendBase if (v.access & VarAccessModeAttribute::REDUCE) { const auto resolvedType = v.type.resolve(cg.getTypeContext()); os << resolvedType.getName() << " _lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(v.access), resolvedType) << ";" << std::endl; - const VarAccessDim varAccessDim = getAccessDim(v.access, cg.getArchetype().getDims()); + const VarAccessDim varAccessDim = getVarAccessDim(v.access, cg.getArchetype().getDims()); reductionTargets.push_back({v.name, resolvedType, getVarAccessMode(v.access), cg.getVarIndex(batchSize, varAccessDim, idx)}); } diff --git a/include/genn/genn/currentSourceInternal.h b/include/genn/genn/currentSourceInternal.h index 7d9ed39d2a..facc48daa2 100644 --- a/include/genn/genn/currentSourceInternal.h +++ b/include/genn/genn/currentSourceInternal.h @@ -54,7 +54,7 @@ class CurrentSourceVarAdapter const std::string &getNameSuffix() const{ return m_CS.getName(); } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- diff --git a/include/genn/genn/customConnectivityUpdateInternal.h b/include/genn/genn/customConnectivityUpdateInternal.h index 9d680ad1b5..557b057a5d 100644 --- a/include/genn/genn/customConnectivityUpdateInternal.h +++ b/include/genn/genn/customConnectivityUpdateInternal.h @@ -61,7 +61,7 @@ class CustomConnectivityUpdateVarAdapter const std::string &getNameSuffix() const{ return m_CU.getName(); } - VarAccessDim getVarDims(const Models::Base::SynapseVar &var) const{ return getAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::SynapseVar &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -92,7 +92,7 @@ class CustomConnectivityUpdatePreVarAdapter const std::string &getNameSuffix() const{ return m_CU.getName(); } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -123,7 +123,7 @@ class CustomConnectivityUpdatePostVarAdapter const std::string &getNameSuffix() const{ return m_CU.getName(); } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarAccessDim(var.access); } private: 
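All of these adapter getVarDims() methods now defer to the renamed free functions declared in varAccess.h further down this patch. As a rough sketch of how the two overload families are intended to compose (the include path, using-directive and chosen enumerators are illustrative assumptions, not part of the patch):

#include "varAccess.h"

using namespace GeNN;

// Neuron and synapse variables encode their dimensions directly in the access
// enumerator, so the single-argument overloads simply mask off the access-mode bits
const VarAccessDim neuronDims = getVarAccessDim(NeuronVarAccess::READ_WRITE);
const VarAccessDim synapseDims = getVarAccessDim(SynapseVarAccess::READ_ONLY_DUPLICATE);

// Custom update variables start from the dimensions of the group they belong to;
// clearVarAccessDim then strips whichever dimensions are encoded in the access
// type (e.g. the batch dimension for a batch reduction)
const VarAccessDim popDims = VarAccessDim::NEURON | VarAccessDim::BATCH;
const VarAccessDim reductionDims = getVarAccessDim(CustomUpdateVarAccess::REDUCE_BATCH_SUM, popDims);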
//---------------------------------------------------------------------------- diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index 7eace9d364..6365ea3b35 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -210,7 +210,7 @@ class CustomUpdateVarAdapter VarAccessDim getVarDims(const Models::Base::CustomUpdateVar &var) const { - return getAccessDim(var.access, m_CU.getDims()); + return getVarAccessDim(var.access, m_CU.getDims()); } private: diff --git a/include/genn/genn/neuronGroupInternal.h b/include/genn/genn/neuronGroupInternal.h index 9eac87d979..3f0f466cbd 100644 --- a/include/genn/genn/neuronGroupInternal.h +++ b/include/genn/genn/neuronGroupInternal.h @@ -78,7 +78,7 @@ class NeuronVarAdapter const std::string &getNameSuffix() const{ return m_NG.getName(); } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- diff --git a/include/genn/genn/synapseGroupInternal.h b/include/genn/genn/synapseGroupInternal.h index 17947d6dd3..99e436bd90 100644 --- a/include/genn/genn/synapseGroupInternal.h +++ b/include/genn/genn/synapseGroupInternal.h @@ -118,7 +118,7 @@ class SynapsePSMVarAdapter bool isVarDelayed(const std::string &) const { return false; } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -170,7 +170,7 @@ class SynapseWUVarAdapter const std::string &getNameSuffix() const{ return m_SG.getName(); } - VarAccessDim getVarDims(const Models::Base::SynapseVar &var) const{ return getAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::SynapseVar &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -201,7 +201,7 @@ class SynapseWUPreVarAdapter bool isVarDelayed(const std::string&) const{ return (m_SG.getDelaySteps() != 0); } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -232,7 +232,7 @@ class SynapseWUPostVarAdapter bool isVarDelayed(const std::string&) const{ return (m_SG.getBackPropDelaySteps() != 0); } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h index 28faaed309..a8c47262bc 100644 --- a/include/genn/genn/varAccess.h +++ b/include/genn/genn/varAccess.h @@ -128,24 +128,24 @@ inline VarAccessDim operator | (VarAccessDim a, VarAccessDim b) //---------------------------------------------------------------------------- // Free functions //---------------------------------------------------------------------------- -inline VarAccessDim clearDim(VarAccessDim a, 
VarAccessDim b) +inline VarAccessDim clearVarAccessDim(VarAccessDim a, VarAccessDim b) { return static_cast(static_cast(a) & ~static_cast(b)); } -inline VarAccessDim getAccessDim(NeuronVarAccess v) +inline VarAccessDim getVarAccessDim(NeuronVarAccess v) { return static_cast(static_cast(v) & ~0x1F); } -inline VarAccessDim getAccessDim(SynapseVarAccess v) +inline VarAccessDim getVarAccessDim(SynapseVarAccess v) { return static_cast(static_cast(v) & ~0x1F); } -inline VarAccessDim getAccessDim(CustomUpdateVarAccess v, VarAccessDim popDims) +inline VarAccessDim getVarAccessDim(CustomUpdateVarAccess v, VarAccessDim popDims) { - return clearDim(popDims, static_cast(static_cast(v) & ~0x1F)); + return clearVarAccessDim(popDims, static_cast(static_cast(v) & ~0x1F)); } inline VarAccessMode getVarAccessMode(VarAccessMode v) diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 2aa9e70973..fcc0af2614 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -62,7 +62,7 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", [this, batchSize, &cuEnv](const std::string&, CustomUpdateVarAccess d) { - return getVarIndex(batchSize, getAccessDim(d, getArchetype().getDims()), "$(id)"); + return getVarIndex(batchSize, getVarAccessDim(d, getArchetype().getDims()), "$(id)"); }); // Create an environment which caches variable references in local variables if they are accessed @@ -189,7 +189,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", [this, batchSize, &cuEnv](const std::string&, CustomUpdateVarAccess d) { - return getVarIndex(batchSize, getAccessDim(d, getArchetype().getDims()), "$(id_syn)"); + return getVarIndex(batchSize, getVarAccessDim(d, getArchetype().getDims()), "$(id_syn)"); }); // Create an environment which caches variable references in local variables if they are accessed diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 4aab494088..453b5e5903 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -1110,8 +1110,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, std::vector neuronStatePushPullFunctions; for(const auto &var : neuronModel->getVars()) { const auto &varInit = n.second.getVarInitialisers().at(var.name); - const unsigned int numCopies = getNumVarCopies(getAccessDim(var.access), batchSize); - const unsigned int numElements = getNumNeuronVarElements(getAccessDim(var.access), n.second.getNumNeurons()); + const unsigned int numCopies = getNumVarCopies(getVarAccessDim(var.access), batchSize); + const unsigned int numElements = getNumNeuronVarElements(getVarAccessDim(var.access), n.second.getNumNeurons()); const size_t count = n.second.isVarQueueRequired(var.name) ? 
numCopies * numElements * n.second.getNumDelaySlots() : numCopies * numElements; const bool autoInitialized = !Utils::areTokensEmpty(varInit.getCodeTokens()); const auto resolvedType = var.type.resolve(modelMerged.getModel().getTypeContext()); @@ -1429,7 +1429,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const bool autoInitialized = !Utils::areTokensEmpty(varInit.getCodeTokens()); const auto resolvedType = wuVar.type.resolve(modelMerged.getModel().getTypeContext()); if(individualWeights || kernelWeights) { - const size_t size = getSynapseVarSize(getAccessDim(wuVar.access), + const size_t size = getSynapseVarSize(getVarAccessDim(wuVar.access), backend, s.second, batchSize); genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 67191e3326..f108f4bee5 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -42,7 +42,7 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [batchSize, &ng](const std::string&, NeuronVarAccess d) { - return ng.getVarIndex(batchSize, getAccessDim(d), "$(id)"); + return ng.getVarIndex(batchSize, getVarAccessDim(d), "$(id)"); }); // Pretty print code back to environment @@ -123,7 +123,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [batchSize, &ng](const std::string&, NeuronVarAccess d) { - return ng.getVarIndex(batchSize, getAccessDim(d), "$(id)"); + return ng.getVarIndex(batchSize, getVarAccessDim(d), "$(id)"); }); // Pretty print code back to environment @@ -204,11 +204,11 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [batchSize, delayed, &synEnv, &ng](const std::string&, NeuronVarAccess d) { - return ng.getReadVarIndex(delayed, batchSize, getAccessDim(d), "$(id)"); + return ng.getReadVarIndex(delayed, batchSize, getVarAccessDim(d), "$(id)"); }, [batchSize, delayed, &synEnv, &ng](const std::string&, NeuronVarAccess d) { - return ng.getWriteVarIndex(delayed, batchSize, getAccessDim(d), "$(id)"); + return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDim(d), "$(id)"); }, [delayed](const std::string&, NeuronVarAccess) { @@ -243,7 +243,7 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentEx // **YUCK** this a bit sketchy as fields may not have been added - could add fields here but need to guarantee uniqueness for(const auto &v : getArchetype().getWUModel()->getPostVars()) { if(getVarAccessMode(v.access) == VarAccessMode::READ_WRITE) { - const VarAccessDim varDims = getAccessDim(v.access); + const VarAccessDim varDims = getVarAccessDim(v.access); env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, varDims, "$(id)") + "] = "); env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, varDims, "$(id)") + "];"); } @@ -296,11 +296,11 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", 
[batchSize, delayed, &ng](const std::string&, NeuronVarAccess d) { - return ng.getReadVarIndex(delayed, batchSize, getAccessDim(d), "$(id)"); + return ng.getReadVarIndex(delayed, batchSize, getVarAccessDim(d), "$(id)"); }, [batchSize, delayed, &ng](const std::string&, NeuronVarAccess d) { - return ng.getWriteVarIndex(delayed, batchSize, getAccessDim(d), "$(id)"); + return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDim(d), "$(id)"); }, [delayed](const std::string&, NeuronVarAccess) { @@ -335,7 +335,7 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(EnvironmentEx // **YUCK** this a bit sketchy as fields may not have been added - could add fields here but need to guarantee uniqueness for(const auto &v : getArchetype().getWUModel()->getPreVars()) { if(getVarAccessMode(v.access) == VarAccessMode::READ_WRITE) { - const VarAccessDim varDims = getAccessDim(v.access); + const VarAccessDim varDims = getVarAccessDim(v.access); env.print("group->" + v.name + suffix + "[" + ng.getWriteVarIndex(true, batchSize, varDims, "$(id)") + "] = "); env.printLine("group->" + v.name + suffix + "[" + ng.getReadVarIndex(true, batchSize, varDims, "$(id)") + "];"); } @@ -515,12 +515,12 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E [batchSize, &neuronEnv, this](const std::string &varName, NeuronVarAccess d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); - return getReadVarIndex(delayed, batchSize, getAccessDim(d), "$(id)") ; + return getReadVarIndex(delayed, batchSize, getVarAccessDim(d), "$(id)") ; }, [batchSize, &neuronEnv, this](const std::string &varName, NeuronVarAccess d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); - return getWriteVarIndex(delayed, batchSize, getAccessDim(d), "$(id)") ; + return getWriteVarIndex(delayed, batchSize, getVarAccessDim(d), "$(id)") ; }, [this](const std::string &varName, NeuronVarAccess) { diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 34fb123854..e70f102f0c 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -32,13 +32,13 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa backend.getDeviceVarPrefix(), [&sg, batchSize](NeuronVarAccess a, const std::string&) { - return sg.getPreWUVarIndex(batchSize, getAccessDim(a), "$(id_pre)"); + return sg.getPreWUVarIndex(batchSize, getVarAccessDim(a), "$(id_pre)"); }, "", true); synEnv.template addVars( backend.getDeviceVarPrefix(), [&sg, batchSize](NeuronVarAccess a, const std::string&) { - return sg.getPostWUVarIndex(batchSize, getAccessDim(a), "$(id_post)"); + return sg.getPostWUVarIndex(batchSize, getVarAccessDim(a), "$(id_post)"); }, "", true); @@ -80,7 +80,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa backend.getDeviceVarPrefix(), [&sg, batchSize](SynapseVarAccess a, const std::string&) { - return sg.getSynVarIndex(batchSize, getAccessDim(a), "$(id_syn)"); + return sg.getSynVarIndex(batchSize, getVarAccessDim(a), "$(id_syn)"); }); } // Otherwise, if weights are procedual @@ -123,7 +123,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa backend.getDeviceVarPrefix(), [&sg, batchSize](SynapseVarAccess a, const std::string&) { - return sg.getKernelVarIndex(batchSize, getAccessDim(a), 
"$(id_kernel)"); + return sg.getKernelVarIndex(batchSize, getVarAccessDim(a), "$(id_kernel)"); }); } diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 84f6f46610..7577f2c74b 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -42,22 +42,22 @@ VarAccessDim VarReference::getVarDims() const // remove dimensions from those of update [](const CURef &ref) { - return getAccessDim(ref.var.access, ref.group->getDims()); + return getVarAccessDim(ref.var.access, ref.group->getDims()); }, // Otherwise, if reference is to the presynaptic variables of a custom connectivity update, // remove BATCH dimension as these are never batched [](const CCUPreRef &ref) { - return clearDim(getAccessDim(ref.var.access), VarAccessDim::BATCH); + return clearVarAccessDim(getVarAccessDim(ref.var.access), VarAccessDim::BATCH); }, // Otherwise, if reference is to the postsynaptic variables of a custom connectivity update, // remove BATCH dimension as these are never batched [](const CCUPostRef &ref) { - return clearDim(getAccessDim(ref.var.access), VarAccessDim::BATCH); + return clearVarAccessDim(getVarAccessDim(ref.var.access), VarAccessDim::BATCH); }, // Otherwise, use dimensionality directly - [](const auto &ref) { return getAccessDim(ref.var.access); }}, + [](const auto &ref) { return getVarAccessDim(ref.var.access); }}, m_Detail); } //---------------------------------------------------------------------------- @@ -210,16 +210,16 @@ VarAccessDim WUVarReference::getVarDims() const // remove dimensions from those of update [](const CURef &ref) { - return getAccessDim(ref.var.access, ref.group->getDims()); + return getVarAccessDim(ref.var.access, ref.group->getDims()); }, // Otherwise, if reference is to the synaptic variables of a custom connectivity update, // remove BATCH dimension as these are never batched [](const CCURef &ref) { - return clearDim(getAccessDim(ref.var.access), VarAccessDim::BATCH); + return clearVarAccessDim(getVarAccessDim(ref.var.access), VarAccessDim::BATCH); }, // Otherwise, use dimensionality directly - [](const WURef &ref){ return getAccessDim(ref.var.access); }}, + [](const WURef &ref){ return getVarAccessDim(ref.var.access); }}, m_Detail); } //---------------------------------------------------------------------------- @@ -276,7 +276,7 @@ std::optional WUVarReference::getTransposeVarDims() const [](const WURef &ref)->std::optional { if(ref.transposeVar) { - return getAccessDim(ref.transposeVar->access); + return getVarAccessDim(ref.transposeVar->access); } else { return std::nullopt; From e39eb985f2fcee4f1336672d1b5c47785f91dd68 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 8 Sep 2023 17:36:00 +0100 Subject: [PATCH 49/60] updated SET_PRE_VARS and SET_POST_VARS macros --- include/genn/genn/customConnectivityUpdateModels.h | 4 ++-- include/genn/genn/weightUpdateModels.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/genn/genn/customConnectivityUpdateModels.h b/include/genn/genn/customConnectivityUpdateModels.h index d2f5f4d2ed..f97cc4a27a 100644 --- a/include/genn/genn/customConnectivityUpdateModels.h +++ b/include/genn/genn/customConnectivityUpdateModels.h @@ -7,8 +7,8 @@ //---------------------------------------------------------------------------- // Macros //---------------------------------------------------------------------------- -#define SET_PRE_VARS(...) virtual VarVec getPreVars() const override{ return __VA_ARGS__; } -#define SET_POST_VARS(...) 
virtual VarVec getPostVars() const override{ return __VA_ARGS__; } +#define SET_PRE_VARS(...) virtual std::vector getPreVars() const override{ return __VA_ARGS__; } +#define SET_POST_VARS(...) virtual std::vector getPostVars() const override{ return __VA_ARGS__; } #define SET_VAR_REFS(...) virtual VarRefVec getVarRefs() const override{ return __VA_ARGS__; } #define SET_PRE_VAR_REFS(...) virtual VarRefVec getPreVarRefs() const override{ return __VA_ARGS__; } diff --git a/include/genn/genn/weightUpdateModels.h b/include/genn/genn/weightUpdateModels.h index 08d8b008cc..3c2538ab91 100644 --- a/include/genn/genn/weightUpdateModels.h +++ b/include/genn/genn/weightUpdateModels.h @@ -17,8 +17,8 @@ #define SET_PRE_DYNAMICS_CODE(PRE_DYNAMICS_CODE) virtual std::string getPreDynamicsCode() const override{ return PRE_DYNAMICS_CODE; } #define SET_POST_DYNAMICS_CODE(POST_DYNAMICS_CODE) virtual std::string getPostDynamicsCode() const override{ return POST_DYNAMICS_CODE; } -#define SET_PRE_VARS(...) virtual VarVec getPreVars() const override{ return __VA_ARGS__; } -#define SET_POST_VARS(...) virtual VarVec getPostVars() const override{ return __VA_ARGS__; } +#define SET_PRE_VARS(...) virtual std::vector getPreVars() const override{ return __VA_ARGS__; } +#define SET_POST_VARS(...) virtual std::vector getPostVars() const override{ return __VA_ARGS__; } //---------------------------------------------------------------------------- // GeNN::WeightUpdateModels::Base From 0d40ab245d785d0a6d425a77f34facdcea30c9f5 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 8 Sep 2023 17:36:09 +0100 Subject: [PATCH 50/60] tests compile --- tests/unit/customConnectivityUpdate.cc | 14 +++++------ tests/unit/customUpdate.cc | 32 +++++++++++++------------- tests/unit/modelSpec.cc | 6 ++--- tests/unit/modelSpecMerged.cc | 8 +++---- tests/unit/models.cc | 6 ++--- tests/unit/neuronGroup.cc | 10 ++++---- tests/unit/neuronModels.cc | 2 +- tests/unit/synapseGroup.cc | 18 +++++++-------- tests/unit/weightUpdateModels.cc | 4 ++-- 9 files changed, 50 insertions(+), 50 deletions(-) diff --git a/tests/unit/customConnectivityUpdate.cc b/tests/unit/customConnectivityUpdate.cc index 61490f413a..ac65ba2fe6 100644 --- a/tests/unit/customConnectivityUpdate.cc +++ b/tests/unit/customConnectivityUpdate.cc @@ -19,7 +19,7 @@ class StaticPulseDendriticDelayReverse : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulseDendriticDelayReverse); - SET_VARS({{"d", "uint8_t", SynapseVarAccess::READ_ONLY}, {"g", "scalar", SynapseVarAccess::READ_ONLY}}); + SET_SYNAPSE_VARS({{"d", "uint8_t", SynapseVarAccess::READ_ONLY}, {"g", "scalar", SynapseVarAccess::READ_ONLY}}); SET_SIM_CODE("addToPostDelay(g, d);\n"); }; @@ -31,7 +31,7 @@ class Sum : public CustomUpdateModels::Base SET_UPDATE_CODE("sum += a;\n"); - SET_VARS({{"sum", "scalar"}}); + SET_CUSTOM_UPDATE_VARS({{"sum", "scalar"}}); SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_ONLY}}); }; IMPLEMENT_SNIPPET(Sum); @@ -41,7 +41,7 @@ class RemoveSynapse : public CustomConnectivityUpdateModels::Base public: DECLARE_SNIPPET(RemoveSynapse); - SET_VARS({{"a", "scalar"}}); + SET_SYNAPSE_VARS({{"a", "scalar"}}); SET_ROW_UPDATE_CODE( "for_each_synapse {\n" " if(id_post == (id_pre + 1)) {\n" @@ -57,7 +57,7 @@ class RemoveSynapseVarRef : public CustomConnectivityUpdateModels::Base public: DECLARE_SNIPPET(RemoveSynapseVarRef); - SET_VARS({{"a", "scalar"}}); + SET_SYNAPSE_VARS({{"a", "scalar"}}); SET_VAR_REFS({{"b", "scalar"}}); SET_ROW_UPDATE_CODE( "for_each_synapse {\n" @@ -108,7 +108,7 @@ class 
Cont : public WeightUpdateModels::Base public: DECLARE_SNIPPET(Cont); - SET_VARS({{"g", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( "addToPost(g * V_pre);\n"); @@ -120,7 +120,7 @@ class ContPost : public WeightUpdateModels::Base public: DECLARE_SNIPPET(ContPost); - SET_VARS({{"g", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( "addToPost(g * V_post);\n"); @@ -132,7 +132,7 @@ bool hasVarRef(const std::vector &varRefs, const std::st return std::find_if(varRefs.cbegin(), varRefs.cend(), [&targetName, &varName](const Models::WUVarReference &r) { - return (r.getTargetName() == targetName) && (r.getVar().name == varName); + return (r.getTargetName() == targetName) && (r.getVarName() == varName); }) != varRefs.cend(); } } // Anonymous namespace diff --git a/tests/unit/customUpdate.cc b/tests/unit/customUpdate.cc index b282dbb3aa..3410f08d82 100644 --- a/tests/unit/customUpdate.cc +++ b/tests/unit/customUpdate.cc @@ -23,9 +23,9 @@ class IzhikevichVariableShared : public NeuronModels::Izhikevich DECLARE_SNIPPET(IzhikevichVariableShared); SET_PARAM_NAMES({}); - SET_VARS({{"V","scalar"}, {"U", "scalar"}, - {"a", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}, {"b", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}, - {"c", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}, {"d", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}}); + SET_NEURON_VARS({{"V","scalar"}, {"U", "scalar"}, + {"a", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}, {"b", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}, + {"c", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}, {"d", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}}); }; IMPLEMENT_SNIPPET(IzhikevichVariableShared); @@ -34,10 +34,10 @@ class StaticPulseDendriticDelaySplit : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulseDendriticDelaySplit); - SET_VARS({{"gCommon", "scalar", SynapseVarAccess::READ_ONLY}, - {"g", "scalar", SynapseVarAccess::READ_ONLY_DUPLICATE}, - {"dCommon", "scalar", SynapseVarAccess::READ_ONLY}, - {"d", "scalar", SynapseVarAccess::READ_ONLY_DUPLICATE}}); + SET_SYNAPSE_VARS({{"gCommon", "scalar", SynapseVarAccess::READ_ONLY}, + {"g", "scalar", SynapseVarAccess::READ_ONLY_DUPLICATE}, + {"dCommon", "scalar", SynapseVarAccess::READ_ONLY}, + {"d", "scalar", SynapseVarAccess::READ_ONLY_DUPLICATE}}); SET_SIM_CODE("$(addToInSynDelay, $(gCommon) + $(g), $(dCommon) + $(d));\n"); }; @@ -49,7 +49,7 @@ class Sum : public CustomUpdateModels::Base SET_UPDATE_CODE("$(sum) = $(a) + $(b);\n"); - SET_VARS({{"sum", "scalar"}}); + SET_CUSTOM_UPDATE_VARS({{"sum", "scalar"}}); SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_ONLY}, {"b", "scalar", VarAccessMode::READ_ONLY}}); }; @@ -61,7 +61,7 @@ class Sum2 : public CustomUpdateModels::Base SET_UPDATE_CODE("$(a) = $(mult) * ($(a) + $(b));\n"); - SET_VARS({{"mult", "scalar", CustomUpdateVarAccess::READ_ONLY}}); + SET_CUSTOM_UPDATE_VARS({{"mult", "scalar", CustomUpdateVarAccess::READ_ONLY}}); SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_WRITE}, {"b", "scalar", VarAccessMode::READ_ONLY}}); }; @@ -73,7 +73,7 @@ class Sum3 : public CustomUpdateModels::Base SET_UPDATE_CODE("$(sum) = $(scale) * ($(a) + $(b));\n"); - SET_VARS({{"sum", "scalar"}, {"scale", "scalar", CustomUpdateVarAccess::READ_ONLY_SHARED_NEURON}}); + SET_CUSTOM_UPDATE_VARS({{"sum", "scalar"}, {"scale", "scalar", CustomUpdateVarAccess::READ_ONLY_SHARED_NEURON}}); SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_WRITE}, {"b", "scalar", 
VarAccessMode::READ_ONLY}}); }; @@ -131,7 +131,7 @@ class Cont : public WeightUpdateModels::Base public: DECLARE_SNIPPET(Cont); - SET_VARS({{"g", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( "addToPost(g * V_pre);\n"); @@ -143,7 +143,7 @@ class Cont2 : public WeightUpdateModels::Base public: DECLARE_SNIPPET(Cont2); - SET_VARS({{"g", "scalar"}, {"x", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}, {"x", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( "addToPost((g + x) * V_pre);\n"); @@ -169,8 +169,8 @@ class ReduceDouble : public CustomUpdateModels::Base "reduction1 = var1;\n" "reduction2 = var2;\n"); - SET_VARS({{"reduction1", "scalar", CustomUpdateVarAccess::REDUCE_BATCH_SUM}, - {"reduction2", "scalar", CustomUpdateVarAccess::REDUCE_NEURON_SUM}}); + SET_CUSTOM_UPDATE_VARS({{"reduction1", "scalar", CustomUpdateVarAccess::REDUCE_BATCH_SUM}, + {"reduction2", "scalar", CustomUpdateVarAccess::REDUCE_NEURON_SUM}}); SET_VAR_REFS({{"var1", "scalar", VarAccessMode::READ_ONLY}, {"var2", "scalar", VarAccessMode::READ_ONLY}}); @@ -183,7 +183,7 @@ class ReduceSharedVar : public CustomUpdateModels::Base SET_UPDATE_CODE("reduction = var;\n"); - SET_VARS({{"reduction", "scalar", CustomUpdateVarAccess::REDUCE_BATCH_SUM}}) + SET_CUSTOM_UPDATE_VARS({{"reduction", "scalar", CustomUpdateVarAccess::REDUCE_BATCH_SUM}}) SET_VAR_REFS({{"var", "scalar", VarAccessMode::READ_ONLY}}); }; IMPLEMENT_SNIPPET(ReduceSharedVar); @@ -195,7 +195,7 @@ class ReduceNeuronSharedVar : public CustomUpdateModels::Base SET_UPDATE_CODE("reduction = var;\n"); - SET_VARS({{"reduction", "scalar", CustomUpdateVarAccess::REDUCE_NEURON_SUM}}) + SET_CUSTOM_UPDATE_VARS({{"reduction", "scalar", CustomUpdateVarAccess::REDUCE_NEURON_SUM}}) SET_VAR_REFS({{"var", "scalar", VarAccessMode::READ_ONLY}}); }; IMPLEMENT_SNIPPET(ReduceNeuronSharedVar); diff --git a/tests/unit/modelSpec.cc b/tests/unit/modelSpec.cc index 2118481ece..d89f242661 100644 --- a/tests/unit/modelSpec.cc +++ b/tests/unit/modelSpec.cc @@ -24,7 +24,7 @@ class AlphaCurr : public PostsynapticModels::Base SET_PARAM_NAMES({"tau"}); - SET_VARS({{"x", "scalar"}}); + SET_NEURON_VARS({{"x", "scalar"}}); SET_DERIVED_PARAMS({ {"expDecay", [](const ParamValues &pars, double dt) { return std::exp(-dt / pars.at("tau")); }}, @@ -39,7 +39,7 @@ class Sum : public CustomUpdateModels::Base SET_UPDATE_CODE("$(sum) = $(a) + $(b);\n"); - SET_VARS({{"sum", "scalar"}}); + SET_CUSTOM_UPDATE_VARS({{"sum", "scalar"}}); SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_ONLY}, {"b", "scalar", VarAccessMode::READ_ONLY}}); }; @@ -50,7 +50,7 @@ class RemoveSynapse : public CustomConnectivityUpdateModels::Base public: DECLARE_SNIPPET(RemoveSynapse); - SET_VARS({{"a", "scalar"}}); + SET_SYNAPSE_VARS({{"a", "scalar"}}); SET_ROW_UPDATE_CODE( "for_each_synapse{\n" " if(id_post == (id_pre + 1)) {\n" diff --git a/tests/unit/modelSpecMerged.cc b/tests/unit/modelSpecMerged.cc index 2bbdc0eff7..2f0651990b 100644 --- a/tests/unit/modelSpecMerged.cc +++ b/tests/unit/modelSpecMerged.cc @@ -32,7 +32,7 @@ class AlphaCurr : public PostsynapticModels::Base SET_PARAM_NAMES({"tau"}); - SET_VARS({{"x", "scalar"}}); + SET_NEURON_VARS({{"x", "scalar"}}); SET_DERIVED_PARAMS({ {"expDecay", [](const ParamValues &pars, double dt) { return std::exp(-dt / pars.at("tau")); }}, @@ -53,7 +53,7 @@ class STDPAdditive : public WeightUpdateModels::Base "Wmin", // 4 - Minimum weight "Wmax"}); // 5 - Maximum weight - SET_VARS({{"g", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", 
"scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); @@ -89,7 +89,7 @@ class Sum : public CustomUpdateModels::Base SET_UPDATE_CODE("sum = a + b;\n"); - SET_VARS({{"sum", "scalar"}}); + SET_CUSTOM_UPDATE_VARS({{"sum", "scalar"}}); SET_PARAM_NAMES({"b"}); SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_ONLY}}); }; @@ -127,7 +127,7 @@ class RemoveSynapsePrePost : public CustomConnectivityUpdateModels::Base public: DECLARE_SNIPPET(RemoveSynapsePrePost); - SET_VARS({{"g", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preThresh", "scalar"}}); SET_POST_VARS({{"postThresh", "scalar"}}); SET_ROW_UPDATE_CODE( diff --git a/tests/unit/models.cc b/tests/unit/models.cc index 18335c61a2..972985166d 100644 --- a/tests/unit/models.cc +++ b/tests/unit/models.cc @@ -24,7 +24,7 @@ class AlphaCurr : public PostsynapticModels::Base SET_PARAM_NAMES({"tau"}); - SET_VARS({{"x", "scalar"}}); + SET_NEURON_VARS({{"x", "scalar"}}); SET_DERIVED_PARAMS({ {"expDecay", [](const ParamValues &pars, double dt) { return std::exp(-dt / pars.at("tau")); }}, @@ -48,7 +48,7 @@ class Cont : public WeightUpdateModels::Base public: DECLARE_SNIPPET(Cont); - SET_VARS({{"g", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( "addToPost(g * V_pre);\n"); @@ -60,7 +60,7 @@ class ContPrePost : public WeightUpdateModels::Base public: DECLARE_SNIPPET(ContPrePost); - SET_VARS({{"g", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); diff --git a/tests/unit/neuronGroup.cc b/tests/unit/neuronGroup.cc index 7d621dbb2d..68e8bc6bcc 100644 --- a/tests/unit/neuronGroup.cc +++ b/tests/unit/neuronGroup.cc @@ -19,7 +19,7 @@ class StaticPulseBack : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulseBack); - SET_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}}); + SET_SYNAPSE_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}}); SET_SIM_CODE( "$(addToInSyn, $(g));\n" @@ -79,7 +79,7 @@ class AlphaCurr : public PostsynapticModels::Base SET_PARAM_NAMES({"tau"}); - SET_VARS({{"x", "scalar"}}); + SET_NEURON_VARS({{"x", "scalar"}}); SET_DERIVED_PARAMS({ {"expDecay", [](const ParamValues &pars, double dt) { return std::exp(-dt / pars.at("tau")); }}, @@ -122,7 +122,7 @@ class LIFAdditional : public NeuronModels::Base {"ExpTC", [](const ParamValues &pars, double dt) { return std::exp(-dt / pars.at("TauM")); }}, {"Rmembrane", [](const ParamValues &pars, double) { return pars.at("TauM") / pars.at("C"); }}}); - SET_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); + SET_NEURON_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); }; IMPLEMENT_SNIPPET(LIFAdditional); @@ -163,7 +163,7 @@ class LIFRandom : public NeuronModels::Base {"ExpTC", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("TauM")); }}, {"Rmembrane", [](const ParamValues &pars, double){ return pars.at("TauM") / pars.at("C"); }}}); - SET_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); + SET_NEURON_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); }; IMPLEMENT_SNIPPET(LIFRandom); @@ -176,7 +176,7 @@ class STDPAdditive : public WeightUpdateModels::Base SET_DERIVED_PARAMS({ {"tauPlusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauPlus")); }}, {"tauMinusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauMinus")); }}}); - SET_VARS({{"g", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); 
diff --git a/tests/unit/neuronModels.cc b/tests/unit/neuronModels.cc index fc2c244a23..30eb0ba3b1 100644 --- a/tests/unit/neuronModels.cc +++ b/tests/unit/neuronModels.cc @@ -42,7 +42,7 @@ class LIFCopy : public NeuronModels::Base {"ExpTC", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("TauM")); }}, {"Rmembrane", [](const ParamValues &pars, double){ return pars.at("TauM") / pars.at("C"); }}}); - SET_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); + SET_NEURON_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); SET_NEEDS_AUTO_REFRACTORY(false); }; diff --git a/tests/unit/synapseGroup.cc b/tests/unit/synapseGroup.cc index 1be3cdd278..d45a1d1a1b 100644 --- a/tests/unit/synapseGroup.cc +++ b/tests/unit/synapseGroup.cc @@ -26,7 +26,7 @@ class STDPAdditive : public WeightUpdateModels::Base SET_DERIVED_PARAMS({ {"tauPlusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauPlus")); }}, {"tauMinusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauMinus")); }}}); - SET_VARS({{"g", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); @@ -58,7 +58,7 @@ class STDPAdditiveEGPWMinMax : public WeightUpdateModels::Base SET_DERIVED_PARAMS({ {"tauPlusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauPlus")); }}, {"tauMinusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauMinus")); }}}); - SET_VARS({{"g", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); SET_EXTRA_GLOBAL_PARAMS({{"Wmin", "scalar"}, {"Wmax", "scalar"}}); @@ -92,7 +92,7 @@ class STDPAdditiveEGPSpike : public WeightUpdateModels::Base SET_DERIVED_PARAMS({ {"tauPlusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauPlus")); }}, {"tauMinusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauMinus")); }}}); - SET_VARS({{"g", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); SET_EXTRA_GLOBAL_PARAMS({{"S", "scalar"}}); @@ -122,7 +122,7 @@ class STDPAdditiveEGPDynamics : public WeightUpdateModels::Base public: DECLARE_SNIPPET(STDPAdditiveEGPDynamics); SET_PARAM_NAMES({"Aplus", "Aminus", "Wmin", "Wmax"}); - SET_VARS({{"g", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); SET_EXTRA_GLOBAL_PARAMS({{"tauPlusDecay", "scalar"}, {"tauMinusDecay", "scalar"}}); @@ -152,7 +152,7 @@ class Continuous : public WeightUpdateModels::Base public: DECLARE_SNIPPET(Continuous); - SET_VARS({{"g", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE("addToPost(g * V_pre);\n"); }; @@ -196,7 +196,7 @@ class StaticPulseDynamics : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulseDynamics); - SET_VARS({ {"g", "scalar"} }); + SET_SYNAPSE_VARS({ {"g", "scalar"} }); SET_SIM_CODE("addToPost(g);\n"); SET_SYNAPSE_DYNAMICS_CODE("g *= 0.99;\n"); @@ -208,7 +208,7 @@ class StaticPulsePostLearn : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulsePostLearn); - SET_VARS({ {"g", "scalar"} }); + SET_SYNAPSE_VARS({ {"g", "scalar"} }); SET_SIM_CODE("addToPost(g);\n"); SET_LEARN_POST_CODE("g *= 0.99;\n"); @@ -243,7 +243,7 @@ class Sum : public CustomUpdateModels::Base SET_UPDATE_CODE("sum = a + b;\n"); - 
SET_VARS({{"sum", "scalar"}}); + SET_CUSTOM_UPDATE_VARS({{"sum", "scalar"}}); SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_ONLY}, {"b", "scalar", VarAccessMode::READ_ONLY}}); }; @@ -284,7 +284,7 @@ class LIFAdditional : public NeuronModels::Base {"ExpTC", [](const ParamValues &pars, double dt) { return std::exp(-dt / pars.at("TauM")); }}, {"Rmembrane", [](const ParamValues &pars, double) { return pars.at("TauM") / pars.at("C"); }}}); - SET_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); + SET_NEURON_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); }; IMPLEMENT_SNIPPET(LIFAdditional); } // Anonymous namespace diff --git a/tests/unit/weightUpdateModels.cc b/tests/unit/weightUpdateModels.cc index ab76154296..dfef43267e 100644 --- a/tests/unit/weightUpdateModels.cc +++ b/tests/unit/weightUpdateModels.cc @@ -17,7 +17,7 @@ class PiecewiseSTDPCopy : public WeightUpdateModels::Base public: SET_PARAM_NAMES({"tLrn", "tChng", "tDecay", "tPunish10", "tPunish01", "gMax", "gMid", "gSlope", "tauShift", "gSyn0"}); - SET_VARS({{"g", "scalar"}, {"gRaw", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}, {"gRaw", "scalar"}}); SET_SIM_CODE( "addToPost(g);\n" @@ -65,7 +65,7 @@ class STDPAdditive : public WeightUpdateModels::Base SET_DERIVED_PARAMS({ {"tauPlusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauPlus")); }}, {"tauMinusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauMinus")); }}}); - SET_VARS({{"g", "scalar"}}); + SET_SYNAPSE_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); From a48b5d440886b3d898b4fd88a1293a0a7070cde9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 8 Sep 2023 17:42:19 +0100 Subject: [PATCH 51/60] fixed bug --- src/genn/genn/weightUpdateModels.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/weightUpdateModels.cc b/src/genn/genn/weightUpdateModels.cc index e3ec73d8d0..ee064f2370 100644 --- a/src/genn/genn/weightUpdateModels.cc +++ b/src/genn/genn/weightUpdateModels.cc @@ -82,7 +82,7 @@ void Base::validate(const std::unordered_map ¶mValues, Utils::validateVecNames(getPostVars(), "Presynaptic variable"); // Validate variable initialisers - Utils::validateInitialisers(vars, preVarValues, "variable", description); + Utils::validateInitialisers(vars, varValues, "variable", description); Utils::validateInitialisers(preVars, preVarValues, "presynaptic variable", description); Utils::validateInitialisers(postVars, postVarValues, "postsynaptic variable", description); } From 2d17150a8526d2a9652ff91d5655af7a56545e8d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 8 Sep 2023 17:51:12 +0100 Subject: [PATCH 52/60] fixed CUDA backend --- include/genn/backends/cuda/backend.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index 3552ee9f63..f683368a65 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -394,7 +394,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT // Add NCCL reduction groupEnv.print("CHECK_NCCL_ERRORS(ncclAllReduce($(_" + v.name + "), $(_" + v.name + "), $(_size)"); - groupEnv.printLine(", " + getNCCLType(resolvedType) + ", " + getNCCLReductionType(v.access) + ", ncclCommunicator, 0));"); + groupEnv.printLine(", " + getNCCLType(resolvedType) + ", " + getNCCLReductionType(getVarAccessMode(v.access)) + ", ncclCommunicator, 0));"); } } @@ -408,7 +408,7 @@ 
class BACKEND_EXPORT Backend : public BackendSIMT [this, v](const auto &g, size_t) { const auto varRef = g.getVarReferences().at(v.name); - return getDeviceVarPrefix() + varRef.getVar().name + varRef.getTargetName(); ; + return getDeviceVarPrefix() + varRef.getVarName() + varRef.getTargetName(); ; }); // Add NCCL reduction From 0011919815747fcda332b5deda26c79fdee5faf7 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 8 Sep 2023 18:04:53 +0100 Subject: [PATCH 53/60] start updating GeNN wrapper --- pygenn/genn_groups.py | 25 ++++----- pygenn/genn_model.py | 73 +++++++++++++++----------- pygenn/src/genn.cc | 118 ++++++++++++++++++++++-------------------- 3 files changed, 118 insertions(+), 98 deletions(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index 0a62d04b28..a69044e128 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -15,7 +15,8 @@ from . import neuron_models, types from .genn import (CustomUpdateWU, SynapseMatrixConnectivity, - SynapseMatrixWeight, VarAccess, VarAccessDim, VarLocation) + SynapseMatrixWeight, VarAccessDim, VarLocation) +from .genn import get_var_access_dim from .model_preprocessor import prepare_model, ExtraGlobalParameter, Variable def _get_num_var_copies(var_dims, batch_size): @@ -355,7 +356,7 @@ def load(self, num_recording_timesteps): # **TODO** delay slots self._load_vars( self.neuron_model.get_vars(), - lambda v: _get_neuron_var_shape(v.access.get_neuron_dims(), + lambda v: _get_neuron_var_shape(get_var_access_dim(v.access), self.size, self._model.batch_size)) @@ -655,7 +656,7 @@ def load(self): if var_loc & VarLocation.HOST: # Determine shape of this variable var_shape = _get_synapse_var_shape( - v.access.get_synapse_dims(), + get_var_access_dim(v.access), self, self._model.batch_size) # Get view @@ -680,7 +681,7 @@ def load(self): if not self._wu_pre_model_fused: self._load_vars( self.wu_model.get_pre_vars(), - lambda v: _get_neuron_var_shape(v.access.get_neuron_dims(), + lambda v: _get_neuron_var_shape(get_var_access_dim(v.access), self.src.size, self._model.batch_size), self.pre_vars, self.get_wu_pre_var_location) @@ -691,7 +692,7 @@ def load(self): if not self._wu_post_model_fused: self._load_vars( self.wu_model.get_post_vars(), - lambda v: _get_neuron_var_shape(v.access.get_neuron_dims(), + lambda v: _get_neuron_var_shape(get_var_access_dim(v.access), self.trg.size, self._model.batch_size), self.post_vars, self.get_wu_post_var_location) @@ -701,7 +702,7 @@ def load(self): # Load postsynaptic update model variables self._load_vars( self.ps_model.get_vars(), - lambda v: _get_neuron_var_shape(v.access.get_neuron_dims(), + lambda v: _get_neuron_var_shape(get_var_access_dim(v.access), self.trg.size, self._model.batch_size), self.psm_vars, self.get_ps_var_location) @@ -840,7 +841,7 @@ def load(self): # Load current source variables self._load_vars( self.current_source_model.get_vars(), - lambda v: _get_neuron_var_shape(v.access.get_neuron_dims(), + lambda v: _get_neuron_var_shape(get_var_access_dim(v.access), self.size, self._model.batch_size)) @@ -875,7 +876,7 @@ def load(self): else 1) self._load_vars( self.custom_update_model.get_vars(), - lambda v: _get_neuron_var_shape(v.access.get_custom_update_dims(self._dims), + lambda v: _get_neuron_var_shape(get_var_access_dim(v.access, self._dims), self.size, batch_size)) self._load_egp() @@ -919,7 +920,7 @@ def load(self): if var_loc & VarLocation.HOST: # Determine shape of this variable var_shape = _get_synapse_var_shape( - v.access.get_custom_update_dims(self._dims), + 
get_var_access_dim(v.access, self._dims), self.synapse_group, batch_size) # Get view @@ -992,7 +993,7 @@ def load(self): if var_loc & VarLocation.HOST: # Determine shape of this variable var_shape = _get_synapse_var_shape( - v.access.get_synapse_dims(), + get_var_access_dim(v.access), self.synapse_group, 1) resolved_type = var_data.type.resolve(self._model.type_context) @@ -1011,12 +1012,12 @@ def load(self): # Load pre and postsynaptic variables self._load_vars( self.model.get_pre_vars(), - lambda v: _get_neuron_var_shape(v.access.get_neuron_dims(), + lambda v: _get_neuron_var_shape(get_var_access_dim(v.access), self.synapse_group.src.size, 1), self.pre_vars, self.get_pre_var_location) self._load_vars( self.model.get_post_vars(), - lambda v: _get_neuron_var_shape(v.access.get_neuron_dims(), + lambda v: _get_neuron_var_shape(get_var_access_dim(v.access), self.synapse_group.trg.size, 1), self.post_vars, self.get_post_var_location) diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index 0e4879b0af..5541e07e82 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -59,16 +59,17 @@ # pygenn imports from .genn import (generate_code, init_logging, CurrentSource, - CurrentSourceModelBase, CustomConnectivityUpdate, - CustomConnectivityUpdateModelBase, CustomUpdate, - CustomUpdateModelBase, CustomUpdateWU, DerivedParam, - EGP, EGPRef, InitSparseConnectivitySnippetBase, + CurrentSourceModelBase, CustomConnectivityUpdate, + CustomConnectivityUpdateModelBase, CustomUpdate, + CustomUpdateModelBase, CustomUpdateVar, CustomUpdateWU, + DerivedParam, EGP, EGPRef, + InitSparseConnectivitySnippetBase, InitToeplitzConnectivitySnippetBase, InitVarSnippetBase, - ModelSpecInternal, NeuronGroup, NeuronModelBase, + ModelSpecInternal, NeuronGroup, NeuronModelBase, NeuronVar, ParamVal, PlogSeverity, PostsynapticModelBase, SparseConnectivityInit, SynapseGroup, SynapseMatrixType, - ToeplitzConnectivityInit, UnresolvedType, Var, VarInit, - VarLocation, VarRef, WeightUpdateModelBase) + SynapseVar, ToeplitzConnectivityInit, UnresolvedType, + VarInit, VarLocation, VarRef, WeightUpdateModelBase) from .shared_library_model import (SharedLibraryModelDouble, SharedLibraryModelFloat) @@ -904,13 +905,16 @@ def create_neuron_model(class_name, param_names=None, lambda self: [ParamVal(a[0], a[1], a[2]) for a in additional_input_vars] + if var_name_types is not None: + body["get_vars"] = \ + lambda self: [NeuronVar(*vn) for vn in var_name_types] + if is_auto_refractory_required is not None: body["is_auto_refractory_required"] = \ lambda self: is_auto_refractory_required return create_model(class_name, NeuronModelBase, param_names, - var_name_types, derived_params, - extra_global_params, body) + derived_params, extra_global_params, body) def create_postsynaptic_model(class_name, param_names=None, @@ -948,9 +952,12 @@ def create_postsynaptic_model(class_name, param_names=None, if apply_input_code is not None: body["get_apply_input_code"] = lambda self: dedent(apply_input_code) + if var_name_types is not None: + body["get_vars"] = \ + lambda self: [NeuronVar(*vn) for vn in var_name_types] + return create_model(class_name, PostsynapticModelBase, param_names, - var_name_types, derived_params, - extra_global_params, body) + derived_params, extra_global_params, body) def create_weight_update_model(class_name, param_names=None, @@ -1039,18 +1046,21 @@ def create_weight_update_model(class_name, param_names=None, if post_dynamics_code is not None: body["get_post_dynamics_code"] = lambda self: dedent(post_dynamics_code) - + + 
if var_name_types is not None: + body["get_vars"] = \ + lambda self: [SynapseVar(*vn) for vn in var_name_types] + if pre_var_name_types is not None: body["get_pre_vars"] = \ - lambda self: [Var(*vn) for vn in pre_var_name_types] + lambda self: [NeuronVar(*vn) for vn in pre_var_name_types] if post_var_name_types is not None: body["get_post_vars"] = \ - lambda self: [Var(*vn) for vn in post_var_name_types] + lambda self: [NeuronVar(*vn) for vn in post_var_name_types] return create_model(class_name, WeightUpdateModelBase, param_names, - var_name_types, derived_params, - extra_global_params, body) + derived_params, extra_global_params, body) def create_current_source_model(class_name, param_names=None, @@ -1085,9 +1095,12 @@ def create_current_source_model(class_name, param_names=None, if injection_code is not None: body["get_injection_code"] = lambda self: dedent(injection_code) + if var_name_types is not None: + body["get_vars"] = \ + lambda self: [NeuronVar(*vn) for vn in var_name_types] + return create_model(class_name, CurrentSourceModelBase, param_names, - var_name_types, derived_params, - extra_global_params, body) + derived_params, extra_global_params, body) def create_custom_update_model(class_name, param_names=None, @@ -1132,13 +1145,16 @@ def create_custom_update_model(class_name, param_names=None, if var_refs is not None: body["get_var_refs"] = lambda self: [VarRef(*v) for v in var_refs] + if var_name_types is not None: + body["get_vars"] = \ + lambda self: [CustomUpdateVar(*vn) for vn in var_name_types] + if extra_global_param_refs is not None: body["get_extra_global_param_refs"] =\ lambda self: [EGPRef(*e) for e in extra_global_param_refs] return create_model(class_name, CustomUpdateModelBase, param_names, - var_name_types, derived_params, - extra_global_params, body) + derived_params, extra_global_params, body) def create_custom_connectivity_update_model(class_name, param_names=None, @@ -1188,6 +1204,10 @@ def create_custom_connectivity_update_model(class_name, if host_update_code is not None: body["get_host_update_code"] = lambda self: dedent(host_update_code) + if var_name_types is not None: + body["get_vars"] = \ + lambda self: [SynapseVar(*vn) for vn in var_name_types] + if pre_var_name_types is not None: body["get_pre_vars"] = \ lambda self: [Var(*vn) for vn in pre_var_name_types] @@ -1208,12 +1228,11 @@ def create_custom_connectivity_update_model(class_name, lambda self: [VarRef(*v) for v in post_var_refs] return create_model(class_name, CustomConnectivityUpdateModelBase, - param_names, var_name_types, derived_params, - extra_global_params, body) + param_names, derived_params, extra_global_params, body) -def create_model(class_name, base, param_names, var_name_types, - derived_params, extra_global_params, custom_body): +def create_model(class_name, base, param_names, derived_params, + extra_global_params, custom_body): """This helper function completes a custom model class creation. 
This part is common for all model classes and is nearly useless on its own @@ -1230,8 +1249,6 @@ def create_model(class_name, base, param_names, var_name_types, class_name -- name of the new class base -- base class param_names -- list of strings with param names of the model - var_name_types -- list of pairs of strings with varible names and - types of the model derived_params -- list of pairs, where the first member is string with name of the derived parameter and the second should be a functor returned by create_dpf_class @@ -1250,10 +1267,6 @@ def ctor(self): if param_names is not None: body["get_param_names"] = lambda self: param_names - if var_name_types is not None: - body["get_vars"] = \ - lambda self: [Var(*vn) for vn in var_name_types] - if derived_params is not None: body["get_derived_params"] = \ lambda self: [DerivedParam(dp[0], dp[1]) diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc index ab588bc3db..5f4141c1a4 100644 --- a/pygenn/src/genn.cc +++ b/pygenn/src/genn.cc @@ -51,17 +51,6 @@ class PySnippet : public SnippetBase virtual Snippet::Base::EGPVec getExtraGlobalParams() const override{ PYBIND11_OVERRIDE_NAME(Snippet::Base::EGPVec, SnippetBase, "get_extra_global_params", getExtraGlobalParams); } }; -//---------------------------------------------------------------------------- -// PyModel -//---------------------------------------------------------------------------- -// 'Trampoline' base class to wrap classes derived off Models::Base -template -class PyModel : public PySnippet -{ -public: - virtual Models::Base::VarVec getVars() const override{ PYBIND11_OVERRIDE_NAME(Models::Base::VarVec, ModelBase, "get_vars", getVars); } -}; - //---------------------------------------------------------------------------- // PyInitSparseConnectivitySnippetBase //---------------------------------------------------------------------------- @@ -106,23 +95,25 @@ class PyInitVarSnippetBase : public PySnippet // PyCurrentSourceModelBase //---------------------------------------------------------------------------- // 'Trampoline' class for current source models -class PyCurrentSourceModelBase : public PyModel +class PyCurrentSourceModelBase : public PySnippet { using Base = CurrentSourceModels::Base; public: virtual std::string getInjectionCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_injection_code", getInjectionCode); } + virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } }; //---------------------------------------------------------------------------- // PyCustomConnectivityUpdateModelBase //---------------------------------------------------------------------------- // 'Trampoline' class for custom connectivity update models -class PyCustomConnectivityUpdateModelBase : public PyModel +class PyCustomConnectivityUpdateModelBase : public PySnippet { using Base = CustomConnectivityUpdateModels::Base; public: - virtual VarVec getPreVars() const override { PYBIND11_OVERRIDE_NAME(Models::Base::VarVec, Base, "get_pre_vars", getPreVars); } - virtual VarVec getPostVars() const override { PYBIND11_OVERRIDE_NAME(Models::Base::VarVec, Base, "get_post_vars", getPostVars); } + virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } + virtual std::vector getPreVars() const override { PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_pre_vars", getPreVars); } + virtual std::vector getPostVars() const override { PYBIND11_OVERRIDE_NAME(std::vector, Base, 
"get_post_vars", getPostVars); } virtual VarRefVec getVarRefs() const override { PYBIND11_OVERRIDE_NAME(VarRefVec, Base, "get_var_refs", getVarRefs); } virtual VarRefVec getPreVarRefs() const override { PYBIND11_OVERRIDE_NAME(VarRefVec, Base, "get_pre_var_refs", getPreVarRefs); } @@ -136,10 +127,11 @@ class PyCustomConnectivityUpdateModelBase : public PyModel +class PyCustomUpdateModelBase : public PySnippet { using Base = CustomUpdateModels::Base; public: + virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } virtual VarRefVec getVarRefs() const override { PYBIND11_OVERRIDE_NAME(VarRefVec, Base, "get_var_refs", getVarRefs); } virtual EGPRefVec getExtraGlobalParamRefs() const override { PYBIND11_OVERRIDE_NAME(EGPRefVec, Base, "get_extra_global_param_refs", getExtraGlobalParamRefs); } virtual std::string getUpdateCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_update_code", getUpdateCode); } @@ -149,7 +141,7 @@ class PyCustomUpdateModelBase : public PyModel // PyNeuronModelBase //---------------------------------------------------------------------------- // 'Trampoline' class for neuron models -class PyNeuronModelBase : public PyModel +class PyNeuronModelBase : public PySnippet { using Base = NeuronModels::Base; public: @@ -157,6 +149,7 @@ class PyNeuronModelBase : public PyModel virtual std::string getThresholdConditionCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_threshold_condition_code", getThresholdConditionCode); } virtual std::string getResetCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_reset_code", getResetCode); } + virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } virtual Models::Base::ParamValVec getAdditionalInputVars() const override { PYBIND11_OVERRIDE_NAME(Models::Base::ParamValVec, Base, "get_additional_input_vars", getAdditionalInputVars); } virtual bool isAutoRefractoryRequired() const override { PYBIND11_OVERRIDE_NAME(bool, Base, "is_auto_refractory_required", isAutoRefractoryRequired); } @@ -166,10 +159,12 @@ class PyNeuronModelBase : public PyModel // PyPostsynapticModelBase //---------------------------------------------------------------------------- // 'Trampoline' class for postsynaptic models -class PyPostsynapticModelBase : public PyModel +class PyPostsynapticModelBase : public PySnippet { using Base = PostsynapticModels::Base; public: + virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } + virtual std::string getDecayCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_decay_code", getDecayCode); } virtual std::string getApplyInputCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_apply_input_code", getApplyInputCode); } }; @@ -178,7 +173,7 @@ class PyPostsynapticModelBase : public PyModel // PyWeightUpdateModelBase //---------------------------------------------------------------------------- // 'Trampoline' class for weight update models -class PyWeightUpdateModelBase : public PyModel +class PyWeightUpdateModelBase : public PySnippet { using Base = WeightUpdateModels::Base; public: @@ -191,8 +186,10 @@ class PyWeightUpdateModelBase : public PyModel virtual std::string getPostSpikeCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_post_spike_code", getPostSpikeCode); } virtual std::string getPreDynamicsCode() const override { 
PYBIND11_OVERRIDE_NAME(std::string, Base, "get_pre_dynamics_code", getPreDynamicsCode); } virtual std::string getPostDynamicsCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_post_dynamics_code", getPostDynamicsCode); } - virtual VarVec getPreVars() const override { PYBIND11_OVERRIDE_NAME(Models::Base::VarVec, Base, "get_pre_vars", getPreVars); } - virtual VarVec getPostVars() const override { PYBIND11_OVERRIDE_NAME(Models::Base::VarVec, Base, "get_post_vars", getPostVars); } + + virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } + virtual std::vector getPreVars() const override { PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_pre_vars", getPreVars); } + virtual std::vector getPostVars() const override { PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_post_vars", getPostVars); } }; CodeGenerator::MemAlloc generateCode(ModelSpecInternal &model, CodeGenerator::BackendBase &backend, @@ -363,6 +360,9 @@ PYBIND11_MODULE(genn, m) m.def("create_egp_ref", pybind11::overload_cast(&createEGPRef), pybind11::return_value_policy::move); m.def("create_psm_egp_ref", pybind11::overload_cast(&createPSMEGPRef), pybind11::return_value_policy::move); m.def("create_wu_egp_ref", pybind11::overload_cast(&createWUEGPRef), pybind11::return_value_policy::move); + m.def("get_var_access_dim", pybind11::overload_cast(&getVarAccessDim)); + m.def("get_var_access_dim", pybind11::overload_cast(&getVarAccessDim)); + m.def("get_var_access_dim", pybind11::overload_cast(&getVarAccessDim)); //------------------------------------------------------------------------ // genn.ModelSpec @@ -693,34 +693,40 @@ PYBIND11_MODULE(genn, m) .def("get_code", &InitVarSnippet::Base::getCode); //------------------------------------------------------------------------ - // genn.VarAccess + // genn.NeuronVar //------------------------------------------------------------------------ - pybind11::class_(m, "VarAccess") - .def("get_neuron_dims", - [](const VarAccess &v) { return v.getDims(); }) - .def("get_synapse_dims", - [](const VarAccess &v) { return v.getDims(); }) - .def("get_custom_update_dims", - [](const VarAccess &v, VarAccessDim cuDims) - { - return clearDim(cuDims, v.getDims()); - }); - + pybind11::class_(m, "NeuronVar") + .def(pybind11::init()) + .def(pybind11::init()) + .def(pybind11::init()) + .def(pybind11::init()) + .def_readonly("name", &Models::Base::NeuronVar::name) + .def_readonly("type", &Models::Base::NeuronVar::type) + .def_readonly("access", &Models::Base::NeuronVar::access); + //------------------------------------------------------------------------ - // genn.Var + // genn.SynapseVar //------------------------------------------------------------------------ - pybind11::class_(m, "Var") - .def(pybind11::init()) + pybind11::class_(m, "SynapseVar") .def(pybind11::init()) - .def(pybind11::init()) .def(pybind11::init()) - .def(pybind11::init()) .def(pybind11::init()) + .def(pybind11::init()) + .def_readonly("name", &Models::Base::SynapseVar::name) + .def_readonly("type", &Models::Base::SynapseVar::type) + .def_readonly("access", &Models::Base::SynapseVar::access); + + //------------------------------------------------------------------------ + // genn.CustomUpdateVar + //------------------------------------------------------------------------ + pybind11::class_(m, "CustomUpdateVar") + .def(pybind11::init()) + .def(pybind11::init()) .def(pybind11::init()) .def(pybind11::init()) - .def_readonly("name", &Models::Base::Var::name) - 
.def_readonly("type", &Models::Base::Var::type) - .def_readonly("access", &Models::Base::Var::access); + .def_readonly("name", &Models::Base::CustomUpdateVar::name) + .def_readonly("type", &Models::Base::CustomUpdateVar::type) + .def_readonly("access", &Models::Base::CustomUpdateVar::access); //------------------------------------------------------------------------ // genn.VarRef @@ -743,26 +749,22 @@ PYBIND11_MODULE(genn, m) .def_readonly("name", &Models::Base::EGPRef::name) .def_readonly("type", &Models::Base::EGPRef::type); - //------------------------------------------------------------------------ - // genn.ModelBase - //------------------------------------------------------------------------ - pybind11::class_>(m, "ModelBase") - .def("get_vars", &Models::Base::getVars); - //------------------------------------------------------------------------ // genn.CurrentSourceModelBase //------------------------------------------------------------------------ - pybind11::class_(m, "CurrentSourceModelBase") + pybind11::class_(m, "CurrentSourceModelBase") .def(pybind11::init<>()) + .def("get_vars", &CurrentSourceModels::Base::getVars) .def("get_injection_code", &CurrentSourceModels::Base::getInjectionCode); //------------------------------------------------------------------------ // genn.CustomConnectivityUpdateModelBase //------------------------------------------------------------------------ - pybind11::class_(m, "CustomConnectivityUpdateModelBase") + pybind11::class_(m, "CustomConnectivityUpdateModelBase") .def(pybind11::init<>()) + .def("get_vars", &CustomConnectivityUpdateModels::Base::getVars) .def("get_pre_vars", &CustomConnectivityUpdateModels::Base::getPreVars) .def("get_post_vars", &CustomConnectivityUpdateModels::Base::getPostVars) @@ -776,9 +778,10 @@ PYBIND11_MODULE(genn, m) //------------------------------------------------------------------------ // genn.CustomUpdateModelBase //------------------------------------------------------------------------ - pybind11::class_(m, "CustomUpdateModelBase") + pybind11::class_(m, "CustomUpdateModelBase") .def(pybind11::init<>()) - + + .def("get_vars", &CustomUpdateModels::Base::getVars) .def("get_var_refs", &CustomUpdateModels::Base::getVarRefs) .def("get_extra_global_param_refs", &CustomUpdateModels::Base::getExtraGlobalParamRefs) .def("get_update_code", &CustomUpdateModels::Base::getUpdateCode); @@ -786,9 +789,10 @@ PYBIND11_MODULE(genn, m) //------------------------------------------------------------------------ // genn.NeuronModelBase //------------------------------------------------------------------------ - pybind11::class_(m, "NeuronModelBase") + pybind11::class_(m, "NeuronModelBase") .def(pybind11::init<>()) - + + .def("get_vars", &NeuronModels::Base::getVars) .def("get_sim_code", &NeuronModels::Base::getSimCode) .def("get_threshold_condition_code", &NeuronModels::Base::getThresholdConditionCode) .def("get_reset_code", &NeuronModels::Base::getResetCode) @@ -798,16 +802,17 @@ PYBIND11_MODULE(genn, m) //------------------------------------------------------------------------ // genn.PostsynapticModelBase //------------------------------------------------------------------------ - pybind11::class_(m, "PostsynapticModelBase") + pybind11::class_(m, "PostsynapticModelBase") .def(pybind11::init<>()) - + + .def("get_vars", &PostsynapticModels::Base::getVars) .def("get_decay_code", &PostsynapticModels::Base::getDecayCode) .def("get_apply_input_code", &PostsynapticModels::Base::getApplyInputCode); 
//------------------------------------------------------------------------ // genn.WeightUpdateModelBase //------------------------------------------------------------------------ - pybind11::class_(m, "WeightUpdateModelBase") + pybind11::class_(m, "WeightUpdateModelBase") .def(pybind11::init<>()) .def("get_sim_code", &WeightUpdateModels::Base::getSimCode) @@ -819,6 +824,7 @@ PYBIND11_MODULE(genn, m) .def("get_post_spike_code", &WeightUpdateModels::Base::getPostSpikeCode) .def("get_pre_dynamics_code", &WeightUpdateModels::Base::getPreDynamicsCode) .def("get_post_dynamics_code", &WeightUpdateModels::Base::getPostDynamicsCode) + .def("get_vars", &WeightUpdateModels::Base::getVars) .def("get_pre_vars", &WeightUpdateModels::Base::getPreVars) .def("get_post_vars", &WeightUpdateModels::Base::getPostVars); From 6e74f6abad9f33d62cb7c766a1146f332066e9b0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 18 Sep 2023 17:11:50 +0100 Subject: [PATCH 54/60] fix std::tieing of temporary variable --- src/genn/genn/models.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 7577f2c74b..6b20820c3a 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -313,8 +313,12 @@ CustomUpdateWU *WUVarReference::getReferencedCustomUpdate() const //------------------------------------------------------------------------ bool WUVarReference::operator < (const WUVarReference &other) const { - return (std::tie(getVarName(), getTargetName(), getTransposeVarName(), getTransposeTargetName()) - < std::tie(other.getVarName(), other.getTargetName(), other.getTransposeVarName(), other.getTransposeTargetName())); + const auto transposeVarName = getTransposeVarName(); + const auto transposeTargetName = getTransposeTargetName(); + const auto otherTransposeVarName = other.getTransposeVarName(); + const auto otherTransposeTargetName = other.getTransposeTargetName(); + return (std::tie(getVarName(), getTargetName(), transposeVarName, transposeTargetName) + < std::tie(other.getVarName(), other.getTargetName(), otherTransposeVarName, otherTransposeTargetName)); } //------------------------------------------------------------------------ WUVarReference WUVarReference::createWUVarReference(SynapseGroup *sg, const std::string &varName, From 2c2fe38e43523b59f02fbfbe949069397905d8e7 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 18 Sep 2023 17:23:12 +0100 Subject: [PATCH 55/60] fixed wrapping of model classes --- pygenn/src/genn.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc index 5f4141c1a4..dde31c37c5 100644 --- a/pygenn/src/genn.cc +++ b/pygenn/src/genn.cc @@ -752,7 +752,7 @@ PYBIND11_MODULE(genn, m) //------------------------------------------------------------------------ // genn.CurrentSourceModelBase //------------------------------------------------------------------------ - pybind11::class_(m, "CurrentSourceModelBase") + pybind11::class_(m, "CurrentSourceModelBase") .def(pybind11::init<>()) .def("get_vars", &CurrentSourceModels::Base::getVars) @@ -761,7 +761,7 @@ PYBIND11_MODULE(genn, m) //------------------------------------------------------------------------ // genn.CustomConnectivityUpdateModelBase //------------------------------------------------------------------------ - pybind11::class_(m, "CustomConnectivityUpdateModelBase") + pybind11::class_(m, "CustomConnectivityUpdateModelBase") .def(pybind11::init<>()) .def("get_vars", 
&CustomConnectivityUpdateModels::Base::getVars) @@ -778,7 +778,7 @@ PYBIND11_MODULE(genn, m) //------------------------------------------------------------------------ // genn.CustomUpdateModelBase //------------------------------------------------------------------------ - pybind11::class_(m, "CustomUpdateModelBase") + pybind11::class_(m, "CustomUpdateModelBase") .def(pybind11::init<>()) .def("get_vars", &CustomUpdateModels::Base::getVars) @@ -789,7 +789,7 @@ PYBIND11_MODULE(genn, m) //------------------------------------------------------------------------ // genn.NeuronModelBase //------------------------------------------------------------------------ - pybind11::class_(m, "NeuronModelBase") + pybind11::class_(m, "NeuronModelBase") .def(pybind11::init<>()) .def("get_vars", &NeuronModels::Base::getVars) @@ -802,7 +802,7 @@ PYBIND11_MODULE(genn, m) //------------------------------------------------------------------------ // genn.PostsynapticModelBase //------------------------------------------------------------------------ - pybind11::class_(m, "PostsynapticModelBase") + pybind11::class_(m, "PostsynapticModelBase") .def(pybind11::init<>()) .def("get_vars", &PostsynapticModels::Base::getVars) @@ -812,7 +812,7 @@ PYBIND11_MODULE(genn, m) //------------------------------------------------------------------------ // genn.WeightUpdateModelBase //------------------------------------------------------------------------ - pybind11::class_(m, "WeightUpdateModelBase") + pybind11::class_(m, "WeightUpdateModelBase") .def(pybind11::init<>()) .def("get_sim_code", &WeightUpdateModels::Base::getSimCode) From c6ab9dfee35a019d4d1b54a509ad9aa5b953f22d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 18 Sep 2023 17:33:16 +0100 Subject: [PATCH 56/60] fixed typos --- pygenn/genn_model.py | 116 +++++++++++++++++++++---------------------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index 5541e07e82..1ccda3fff6 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -210,7 +210,7 @@ def __init__(self, precision="float", model_name="GeNNModel", types.Uint16: np.uint16, types.Int8: np.int8, types.Uint8: np.uint8, - types.Bool: np.bool8} + types.Bool: np.bool_} @property def backend_name(self): @@ -851,6 +851,58 @@ class as string or instance of class init_toeplitz_connectivity_snippets) return ToeplitzConnectivityInit(init_toeplitz_connect_snippet, param_space) + +def create_model(class_name, base, param_names, derived_params, + extra_global_params, custom_body): + """This helper function completes a custom model class creation. + + This part is common for all model classes and is nearly useless on its own + unless you specify custom_body. 
+ See also: + create_neuron_model + create_weight_update_model + create_postsynaptic_model + create_current_source_model + create_var_init_snippet + create_sparse_connect_init_snippet + + Args: + class_name -- name of the new class + base -- base class + param_names -- list of strings with param names of the model + derived_params -- list of pairs, where the first member is string with + name of the derived parameter and the second should + be a functor returned by create_dpf_class + extra_global_params -- list of pairs of strings with names and types of + additional parameters + custom_body -- dictionary with attributes and methods of the new class + """ + + def ctor(self): + base.__init__(self) + + body = { + "__init__": ctor, + } + + if param_names is not None: + body["get_param_names"] = lambda self: param_names + + if derived_params is not None: + body["get_derived_params"] = \ + lambda self: [DerivedParam(dp[0], dp[1]) + for dp in derived_params] + + if extra_global_params is not None: + body["get_extra_global_params"] = \ + lambda self: [EGP(egp[0], egp[1]) + for egp in extra_global_params] + + if custom_body is not None: + body.update(custom_body) + + return type(class_name, (base,), body)() + def create_neuron_model(class_name, param_names=None, var_name_types=None, derived_params=None, sim_code=None, threshold_condition_code=None, @@ -1210,11 +1262,11 @@ def create_custom_connectivity_update_model(class_name, if pre_var_name_types is not None: body["get_pre_vars"] = \ - lambda self: [Var(*vn) for vn in pre_var_name_types] + lambda self: [NeuronVar(*vn) for vn in pre_var_name_types] if post_var_name_types is not None: body["get_post_vars"] = \ - lambda self: [Var(*vn) for vn in post_var_name_types] + lambda self: [NeuronVar(*vn) for vn in post_var_name_types] if var_refs is not None: body["get_var_refs"] = lambda self: [VarRef(*v) for v in var_refs] @@ -1231,58 +1283,6 @@ def create_custom_connectivity_update_model(class_name, param_names, derived_params, extra_global_params, body) -def create_model(class_name, base, param_names, derived_params, - extra_global_params, custom_body): - """This helper function completes a custom model class creation. - - This part is common for all model classes and is nearly useless on its own - unless you specify custom_body. 
- See also: - create_neuron_model - create_weight_update_model - create_postsynaptic_model - create_current_source_model - create_var_init_snippet - create_sparse_connect_init_snippet - - Args: - class_name -- name of the new class - base -- base class - param_names -- list of strings with param names of the model - derived_params -- list of pairs, where the first member is string with - name of the derived parameter and the second should - be a functor returned by create_dpf_class - extra_global_params -- list of pairs of strings with names and types of - additional parameters - custom_body -- dictionary with attributes and methods of the new class - """ - - def ctor(self): - base.__init__(self) - - body = { - "__init__": ctor, - } - - if param_names is not None: - body["get_param_names"] = lambda self: param_names - - if derived_params is not None: - body["get_derived_params"] = \ - lambda self: [DerivedParam(dp[0], dp[1]) - for dp in derived_params] - - if extra_global_params is not None: - body["get_extra_global_params"] = \ - lambda self: [EGP(egp[0], egp[1]) - for egp in extra_global_params] - - if custom_body is not None: - body.update(custom_body) - - return type(class_name, (base,), body)() - - def create_var_init_snippet(class_name, param_names=None, derived_params=None, var_init_code=None, @@ -1314,7 +1314,7 @@ def create_var_init_snippet(class_name, param_names=None, body["get_code"] = lambda self: dedent(var_init_code) return create_model(class_name, InitVarSnippetBase, - param_names, None, derived_params, + param_names, derived_params, extra_global_params, body) @@ -1381,7 +1381,7 @@ def create_sparse_connect_init_snippet(class_name, lambda self: make_cksf(calc_kernel_size_func) return create_model(class_name, InitSparseConnectivitySnippetBase, param_names, - None, derived_params, extra_global_params, body) + derived_params, extra_global_params, body) def create_toeplitz_connect_init_snippet(class_name, param_names=None, @@ -1436,7 +1436,7 @@ def create_toeplitz_connect_init_snippet(class_name, lambda self: make_cksf(calc_kernel_size_func) return create_model(class_name, InitToeplitzConnectivitySnippetBase, param_names, - None, derived_params, extra_global_params, body) + derived_params, extra_global_params, body) @deprecated("this wrapper is now unnecessary - use callables directly") def create_dpf_class(dp_func): From 2fbf52f13c4a8d08f7974530326cb6f01bd60c1e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 26 Sep 2023 11:58:49 +0100 Subject: [PATCH 57/60] add additional test to custom updates to check whether REDUCE access mode is set on variables/variable references in model but no variables with matching shape are referenced --- include/genn/genn/customUpdate.h | 2 ++ src/genn/genn/customUpdate.cc | 35 +++++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index 6365ea3b35..dc3bb3d144 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -73,6 +73,8 @@ class GENN_EXPORT CustomUpdateBase bool isZeroCopyEnabled() const; + bool isModelReduction() const; + //! Updates hash with custom update /*! 
NOTE: this can only be called after model is finalized */ void updateHash(boost::uuids::detail::sha1 &hash) const; diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index dee79c24b3..a7eff8f775 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -87,6 +87,27 @@ bool CustomUpdateBase::isZeroCopyEnabled() const [](VarLocation loc) { return (loc & VarLocation::ZERO_COPY); }); } //---------------------------------------------------------------------------- +bool CustomUpdateBase::isModelReduction() const +{ + // Return true if any variables have REDUCE flag in their access mode + const auto vars = getCustomUpdateModel()->getVars(); + if(std::any_of(vars.cbegin(), vars.cend(), + [](const auto &v){ return (v.access & VarAccessModeAttribute::REDUCE); })) + { + return true; + } + + // Return true if any variable references have REDUCE flag in their access mode + const auto varRefs = getCustomUpdateModel()->getVarRefs(); + if(std::any_of(varRefs.cbegin(), varRefs.cend(), + [](const auto &v){ return (v.access & VarAccessModeAttribute::REDUCE); })) + { + return true; + } + + return false; +} +//---------------------------------------------------------------------------- void CustomUpdateBase::updateHash(boost::uuids::detail::sha1 &hash) const { Utils::updateHash(getCustomUpdateModel()->getHashDigest(), hash); @@ -135,9 +156,15 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro Models::checkVarReferenceTypes(m_VarReferences, getCustomUpdateModel()->getVarRefs()); // Check only one type of reduction is specified - if (isBatchReduction() && isNeuronReduction()) { + const bool batchReduction = isBatchReduction(); + const bool neuronReduction = isNeuronReduction(); + if (batchReduction && neuronReduction) { throw std::runtime_error("Custom updates cannot perform batch and neuron reductions simultaneously."); } + // Otherwise, if model specifies reduction operations but none are correctly configured + else if(isModelReduction() && !batchReduction && !neuronReduction) { + throw std::runtime_error("Custom updates uses reduction model but shape is incorrect."); + } // Give error if any sizes differ if(std::any_of(m_VarReferences.cbegin(), m_VarReferences.cend(), @@ -225,6 +252,11 @@ CustomUpdateWU::CustomUpdateWU(const std::string &name, const std::string &updat // Check variable reference types Models::checkVarReferenceTypes(m_VarReferences, getCustomUpdateModel()->getVarRefs()); + // If model specifies reduction operations but none are correctly configured + if(isModelReduction() && !isBatchReduction()) { + throw std::runtime_error("Custom updates uses reduction model but shape is incorrect."); + } + // Give error if references point to different synapse groups // **NOTE** this could be relaxed for dense if(std::any_of(m_VarReferences.cbegin(), m_VarReferences.cend(), @@ -235,6 +267,7 @@ CustomUpdateWU::CustomUpdateWU(const std::string &name, const std::string &updat { throw std::runtime_error("All referenced variables must belong to the same synapse group."); } + // If this is a transpose operation if(isTransposeOperation()) { // Check that it isn't also a reduction From 1410f878d7e7c0e88b0122f73ca4c988886d5a06 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 31 Oct 2023 10:17:33 +0000 Subject: [PATCH 58/60] simplify variable dimensions: * ELEMENT and BATCH dimensions * VarAccess and CustomUpdateVarAccess * Var and CustomUpdateVar --- include/genn/genn/currentSourceInternal.h | 4 +- 
include/genn/genn/currentSourceModels.h | 4 +- .../genn/customConnectivityUpdateInternal.h | 12 ++-- .../genn/customConnectivityUpdateModels.h | 10 +-- include/genn/genn/customUpdate.h | 2 +- include/genn/genn/models.h | 50 +++++--------- include/genn/genn/neuronGroupInternal.h | 4 +- include/genn/genn/neuronModels.h | 22 +++---- include/genn/genn/postsynapticModels.h | 2 +- include/genn/genn/synapseGroupInternal.h | 16 ++--- include/genn/genn/varAccess.h | 66 ++++--------------- include/genn/genn/weightUpdateModels.h | 18 ++--- .../backends/single_threaded_cpu/backend.cc | 4 +- src/genn/genn/code_generator/backendBase.cc | 4 +- src/genn/genn/code_generator/backendSIMT.cc | 14 ++-- .../code_generator/customUpdateGroupMerged.cc | 4 +- .../genn/code_generator/generateRunner.cc | 17 +---- .../genn/code_generator/initGroupMerged.cc | 14 ++-- .../code_generator/neuronUpdateGroupMerged.cc | 40 +++++------ .../presynapticUpdateStrategySIMT.cc | 10 +-- .../synapseUpdateGroupMerged.cc | 20 +++--- src/genn/genn/models.cc | 9 +-- tests/unit/customConnectivityUpdate.cc | 10 +-- tests/unit/customUpdate.cc | 44 ++++++------- tests/unit/modelSpec.cc | 4 +- tests/unit/modelSpecMerged.cc | 6 +- tests/unit/models.cc | 6 +- tests/unit/neuronGroup.cc | 10 +-- tests/unit/neuronModels.cc | 2 +- tests/unit/synapseGroup.cc | 16 ++--- tests/unit/weightUpdateModels.cc | 4 +- 31 files changed, 189 insertions(+), 259 deletions(-) diff --git a/include/genn/genn/currentSourceInternal.h b/include/genn/genn/currentSourceInternal.h index facc48daa2..2fffcd3b75 100644 --- a/include/genn/genn/currentSourceInternal.h +++ b/include/genn/genn/currentSourceInternal.h @@ -46,7 +46,7 @@ class CurrentSourceVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_CS.getVarLocation(varName); } - std::vector getDefs() const{ return m_CS.getCurrentSourceModel()->getVars(); } + std::vector getDefs() const{ return m_CS.getCurrentSourceModel()->getVars(); } const std::unordered_map &getInitialisers() const{ return m_CS.getVarInitialisers(); } @@ -54,7 +54,7 @@ class CurrentSourceVarAdapter const std::string &getNameSuffix() const{ return m_CS.getName(); } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- diff --git a/include/genn/genn/currentSourceModels.h b/include/genn/genn/currentSourceModels.h index 65709d8253..2a66a8eb39 100644 --- a/include/genn/genn/currentSourceModels.h +++ b/include/genn/genn/currentSourceModels.h @@ -33,7 +33,7 @@ class GENN_EXPORT Base : public Models::Base virtual std::string getInjectionCode() const{ return ""; } //! 
Gets model variables - virtual std::vector getVars() const{ return {}; } + virtual std::vector getVars() const{ return {}; } //---------------------------------------------------------------------------- // Public API @@ -113,7 +113,7 @@ class PoissonExp : public Base "current *= ExpDecay;\n"); SET_PARAM_NAMES({"weight", "tauSyn", "rate"}); - SET_NEURON_VARS({{"current", "scalar"}}); + SET_VARS({{"current", "scalar"}}); SET_DERIVED_PARAMS({ {"ExpDecay", [](const std::unordered_map &pars, double dt){ return std::exp(-dt / pars.at("tauSyn")); }}, {"Init", [](const std::unordered_map &pars, double dt){ return pars.at("weight") * (1.0 - std::exp(-dt / pars.at("tauSyn"))) * (pars.at("tauSyn") / dt); }}, diff --git a/include/genn/genn/customConnectivityUpdateInternal.h b/include/genn/genn/customConnectivityUpdateInternal.h index 557b057a5d..1d56138c2a 100644 --- a/include/genn/genn/customConnectivityUpdateInternal.h +++ b/include/genn/genn/customConnectivityUpdateInternal.h @@ -55,13 +55,13 @@ class CustomConnectivityUpdateVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_CU.getVarLocation(varName); } - std::vector getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getVars(); } + std::vector getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getVars(); } const std::unordered_map &getInitialisers() const{ return m_CU.getVarInitialisers(); } const std::string &getNameSuffix() const{ return m_CU.getName(); } - VarAccessDim getVarDims(const Models::Base::SynapseVar &var) const{ return getVarAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -84,7 +84,7 @@ class CustomConnectivityUpdatePreVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_CU.getPreVarLocation(varName); } - std::vector getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getPreVars(); } + std::vector getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getPreVars(); } const std::unordered_map &getInitialisers() const{ return m_CU.getPreVarInitialisers(); } @@ -92,7 +92,7 @@ class CustomConnectivityUpdatePreVarAdapter const std::string &getNameSuffix() const{ return m_CU.getName(); } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -115,7 +115,7 @@ class CustomConnectivityUpdatePostVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_CU.getPostVarLocation(varName); } - std::vector getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getPostVars(); } + std::vector getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getPostVars(); } const std::unordered_map &getInitialisers() const{ return m_CU.getPostVarInitialisers(); } @@ -123,7 +123,7 @@ class CustomConnectivityUpdatePostVarAdapter const std::string &getNameSuffix() const{ return m_CU.getName(); } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarAccessDim(var.access); 
} + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- diff --git a/include/genn/genn/customConnectivityUpdateModels.h b/include/genn/genn/customConnectivityUpdateModels.h index f97cc4a27a..8083029178 100644 --- a/include/genn/genn/customConnectivityUpdateModels.h +++ b/include/genn/genn/customConnectivityUpdateModels.h @@ -7,8 +7,8 @@ //---------------------------------------------------------------------------- // Macros //---------------------------------------------------------------------------- -#define SET_PRE_VARS(...) virtual std::vector getPreVars() const override{ return __VA_ARGS__; } -#define SET_POST_VARS(...) virtual std::vector getPostVars() const override{ return __VA_ARGS__; } +#define SET_PRE_VARS(...) virtual std::vector getPreVars() const override{ return __VA_ARGS__; } +#define SET_POST_VARS(...) virtual std::vector getPostVars() const override{ return __VA_ARGS__; } #define SET_VAR_REFS(...) virtual VarRefVec getVarRefs() const override{ return __VA_ARGS__; } #define SET_PRE_VAR_REFS(...) virtual VarRefVec getPreVarRefs() const override{ return __VA_ARGS__; } @@ -31,14 +31,14 @@ class GENN_EXPORT Base : public Models::Base //---------------------------------------------------------------------------- //! Gets names and types (as strings) of state variables that are common //! across all synapses coming from the same presynaptic neuron - virtual std::vector getPreVars() const { return {}; } + virtual std::vector getPreVars() const { return {}; } //! Gets names and types (as strings) of state variables that are common //! across all synapses going to the same postsynaptic neuron - virtual std::vector getPostVars() const { return {}; } + virtual std::vector getPostVars() const { return {}; } //! Gets model variables - virtual std::vector getVars() const{ return {}; } + virtual std::vector getVars() const{ return {}; } //! Gets names and types (as strings) of synapse variable references virtual VarRefVec getVarRefs() const { return {}; } diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index dc3bb3d144..05ff20e8a2 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -272,7 +272,7 @@ class GENN_EXPORT CustomUpdate : public CustomUpdateBase // Protected const methods //------------------------------------------------------------------------ bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDim::BATCH); } - bool isNeuronReduction() const { return isReduction(getVarReferences(), VarAccessDim::NEURON); } + bool isNeuronReduction() const { return isReduction(getVarReferences(), VarAccessDim::ELEMENT); } const NeuronGroup *getDelayNeuronGroup() const { return m_DelayNeuronGroup; } diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 4ab2b4711b..e9c60a5b54 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -33,8 +33,7 @@ class CustomConnectivityUpdateInternal; //---------------------------------------------------------------------------- // Macros //---------------------------------------------------------------------------- -#define SET_NEURON_VARS(...) virtual std::vector getVars() const override{ return __VA_ARGS__; } -#define SET_SYNAPSE_VARS(...) virtual std::vector getVars() const override{ return __VA_ARGS__; } +#define SET_VARS(...) 
virtual std::vector getVars() const override{ return __VA_ARGS__; } #define SET_CUSTOM_UPDATE_VARS(...) virtual std::vector getVars() const override{ return __VA_ARGS__; } #define DEFINE_REF_DETAIL_STRUCT(NAME, GROUP_TYPE, VAR_TYPE) using NAME = Detail @@ -76,27 +75,15 @@ class GENN_EXPORT Base : public Snippet::Base A access; }; - struct NeuronVar : public VarBase + struct Var : public VarBase { - using VarBase::VarBase; + using VarBase::VarBase; - NeuronVar(const std::string &n, const Type::ResolvedType &t) - : VarBase(n, t, NeuronVarAccess::READ_WRITE) + Var(const std::string &n, const Type::ResolvedType &t) + : VarBase(n, t, VarAccess::READ_WRITE) {} - NeuronVar(const std::string &n, const std::string &t) - : VarBase(n, t, NeuronVarAccess::READ_WRITE) - {} - }; - - struct SynapseVar : public VarBase - { - using VarBase::VarBase; - - SynapseVar(const std::string &n, const Type::ResolvedType &t) - : VarBase(n, t, SynapseVarAccess::READ_WRITE) - {} - SynapseVar(const std::string &n, const std::string &t) - : VarBase(n, t, SynapseVarAccess::READ_WRITE) + Var(const std::string &n, const std::string &t) + : VarBase(n, t, VarAccess::READ_WRITE) {} }; @@ -223,14 +210,14 @@ class GENN_EXPORT VarReference : public VarReferenceBase //------------------------------------------------------------------------ // Typedefines //------------------------------------------------------------------------ - DEFINE_REF_DETAIL_STRUCT(NGRef, NeuronGroupInternal, Base::NeuronVar); - DEFINE_REF_DETAIL_STRUCT(PSMRef, SynapseGroupInternal, Base::NeuronVar); - DEFINE_REF_DETAIL_STRUCT(WUPreRef, SynapseGroupInternal, Base::NeuronVar); - DEFINE_REF_DETAIL_STRUCT(WUPostRef, SynapseGroupInternal, Base::NeuronVar); - DEFINE_REF_DETAIL_STRUCT(CSRef, CurrentSourceInternal, Base::NeuronVar); + DEFINE_REF_DETAIL_STRUCT(NGRef, NeuronGroupInternal, Base::Var); + DEFINE_REF_DETAIL_STRUCT(PSMRef, SynapseGroupInternal, Base::Var); + DEFINE_REF_DETAIL_STRUCT(WUPreRef, SynapseGroupInternal, Base::Var); + DEFINE_REF_DETAIL_STRUCT(WUPostRef, SynapseGroupInternal, Base::Var); + DEFINE_REF_DETAIL_STRUCT(CSRef, CurrentSourceInternal, Base::Var); DEFINE_REF_DETAIL_STRUCT(CURef, CustomUpdateInternal, Base::CustomUpdateVar); - DEFINE_REF_DETAIL_STRUCT(CCUPreRef, CustomConnectivityUpdateInternal, Base::NeuronVar); - DEFINE_REF_DETAIL_STRUCT(CCUPostRef, CustomConnectivityUpdateInternal, Base::NeuronVar); + DEFINE_REF_DETAIL_STRUCT(CCUPreRef, CustomConnectivityUpdateInternal, Base::Var); + DEFINE_REF_DETAIL_STRUCT(CCUPostRef, CustomConnectivityUpdateInternal, Base::Var); //! Variant type used to store 'detail' using DetailType = std::variant transposeVar; + Base::Var var; + std::optional transposeVar; }; //------------------------------------------------------------------------ // Typedefines //------------------------------------------------------------------------ DEFINE_REF_DETAIL_STRUCT(CURef, CustomUpdateWUInternal, Base::CustomUpdateVar); - DEFINE_REF_DETAIL_STRUCT(CCURef, CustomConnectivityUpdateInternal, Base::SynapseVar); + DEFINE_REF_DETAIL_STRUCT(CCURef, CustomConnectivityUpdateInternal, Base::Var); //! 
Variant type used to store 'detail' using DetailType = std::variant; @@ -376,8 +363,7 @@ class GENN_EXPORT EGPReference //---------------------------------------------------------------------------- // updateHash overrides //---------------------------------------------------------------------------- -GENN_EXPORT void updateHash(const Base::NeuronVar &v, boost::uuids::detail::sha1 &hash); -GENN_EXPORT void updateHash(const Base::SynapseVar &v, boost::uuids::detail::sha1 &hash); +GENN_EXPORT void updateHash(const Base::Var &v, boost::uuids::detail::sha1 &hash); GENN_EXPORT void updateHash(const Base::CustomUpdateVar &v, boost::uuids::detail::sha1 &hash); GENN_EXPORT void updateHash(const Base::VarRef &v, boost::uuids::detail::sha1 &hash); GENN_EXPORT void updateHash(const Base::EGPRef &e, boost::uuids::detail::sha1 &hash); diff --git a/include/genn/genn/neuronGroupInternal.h b/include/genn/genn/neuronGroupInternal.h index 3f0f466cbd..61313c6bd7 100644 --- a/include/genn/genn/neuronGroupInternal.h +++ b/include/genn/genn/neuronGroupInternal.h @@ -70,7 +70,7 @@ class NeuronVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_NG.getVarLocation(varName); } - std::vector getDefs() const{ return m_NG.getNeuronModel()->getVars(); } + std::vector getDefs() const{ return m_NG.getNeuronModel()->getVars(); } const std::unordered_map &getInitialisers() const{ return m_NG.getVarInitialisers(); } @@ -78,7 +78,7 @@ class NeuronVarAdapter const std::string &getNameSuffix() const{ return m_NG.getName(); } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- diff --git a/include/genn/genn/neuronModels.h b/include/genn/genn/neuronModels.h index 1e2dccb3be..de703bf28c 100644 --- a/include/genn/genn/neuronModels.h +++ b/include/genn/genn/neuronModels.h @@ -35,7 +35,7 @@ class GENN_EXPORT Base : public Models::Base // Declared virtuals //---------------------------------------------------------------------------- //! Gets model variables - virtual std::vector getVars() const{ return {}; } + virtual std::vector getVars() const{ return {}; } //! Gets the code that defines the execution of one timestep of integration of the neuron model. /*! The code will refer to $(NN) for the value of the variable with name "NN". 
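To make the simplified variable API concrete, a hedged sketch of a user-defined neuron model written against it (class name, variable names and code are invented for illustration; the built-in models in the hunks below are updated in the same way):

    class MyCounter : public NeuronModels::Base
    {
        DECLARE_SNIPPET(MyCounter);

        SET_SIM_CODE("$(count) += $(scale);\n");

        // With the unified VarAccess enum there is no longer a separate
        // NeuronVarAccess/SynapseVarAccess to choose between
        SET_VARS({{"count", "scalar"},                          // READ_WRITE: per-element, duplicated across batches
                  {"scale", "scalar", VarAccess::READ_ONLY}});  // read-only, per-element, shared across batches
    };
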
@@ -128,7 +128,7 @@ class RulkovMap : public Base SET_THRESHOLD_CONDITION_CODE("$(V) >= $(ip2)"); SET_PARAM_NAMES({"Vspike", "alpha", "y", "beta"}); - SET_NEURON_VARS({{"V","scalar"}, {"preV", "scalar"}}); + SET_VARS({{"V","scalar"}, {"preV", "scalar"}}); SET_DERIVED_PARAMS({ {"ip0", [](const std::unordered_map &pars, double){ return pars.at("Vspike") * pars.at("Vspike") * pars.at("alpha"); }}, @@ -177,7 +177,7 @@ class Izhikevich : public Base SET_THRESHOLD_CONDITION_CODE("$(V) >= 29.99"); SET_PARAM_NAMES({"a", "b", "c", "d"}); - SET_NEURON_VARS({{"V","scalar"}, {"U", "scalar"}}); + SET_VARS({{"V","scalar"}, {"U", "scalar"}}); SET_NEEDS_AUTO_REFRACTORY(false); }; @@ -205,9 +205,9 @@ class IzhikevichVariable : public Izhikevich DECLARE_SNIPPET(NeuronModels::IzhikevichVariable); SET_PARAM_NAMES({}); - SET_NEURON_VARS({{"V","scalar"}, {"U", "scalar"}, - {"a", "scalar", NeuronVarAccess::READ_ONLY}, {"b", "scalar", NeuronVarAccess::READ_ONLY}, - {"c", "scalar", NeuronVarAccess::READ_ONLY}, {"d", "scalar", NeuronVarAccess::READ_ONLY}}); + SET_VARS({{"V","scalar"}, {"U", "scalar"}, + {"a", "scalar", VarAccess::READ_ONLY}, {"b", "scalar", VarAccess::READ_ONLY}, + {"c", "scalar", VarAccess::READ_ONLY}, {"d", "scalar", VarAccess::READ_ONLY}}); }; //---------------------------------------------------------------------------- @@ -247,7 +247,7 @@ class LIF : public Base {"ExpTC", [](const std::unordered_map &pars, double dt){ return std::exp(-dt / pars.at("TauM")); }}, {"Rmembrane", [](const std::unordered_map &pars, double){ return pars.at("TauM") / pars.at("C"); }}}); - SET_NEURON_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); + SET_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); SET_NEEDS_AUTO_REFRACTORY(false); }; @@ -291,7 +291,7 @@ class SpikeSourceArray : public Base "$(startSpike) != $(endSpike) && " "$(t) >= $(spikeTimes)[$(startSpike)]" ); SET_RESET_CODE( "$(startSpike)++;\n" ); - SET_NEURON_VARS({{"startSpike", "unsigned int"}, {"endSpike", "unsigned int", NeuronVarAccess::READ_ONLY_DUPLICATE}}); + SET_VARS({{"startSpike", "unsigned int"}, {"endSpike", "unsigned int", VarAccess::READ_ONLY_DUPLICATE}}); SET_EXTRA_GLOBAL_PARAMS( {{"spikeTimes", "scalar*"}} ); SET_NEEDS_AUTO_REFRACTORY(false); }; @@ -351,7 +351,7 @@ class Poisson : public Base SET_THRESHOLD_CONDITION_CODE("$(V) >= $(Vspike)"); SET_PARAM_NAMES({"trefract", "tspike", "Vspike", "Vrest"}); - SET_NEURON_VARS({{"V", "scalar"}, {"spikeTime", "scalar"}}); + SET_VARS({{"V", "scalar"}, {"spikeTime", "scalar"}}); SET_EXTRA_GLOBAL_PARAMS({{"firingProb", "scalar*"}, {"offset", "unsigned int"}}); }; @@ -387,7 +387,7 @@ class PoissonNew : public Base SET_THRESHOLD_CONDITION_CODE("$(timeStepToSpike) <= 0.0"); SET_PARAM_NAMES({"rate"}); - SET_NEURON_VARS({{"timeStepToSpike", "scalar"}}); + SET_VARS({{"timeStepToSpike", "scalar"}}); SET_DERIVED_PARAMS({{"isi", [](const std::unordered_map &pars, double dt){ return 1000.0 / (pars.at("rate") * dt); }}}); SET_NEEDS_AUTO_REFRACTORY(false); }; @@ -485,7 +485,7 @@ class TraubMiles : public Base SET_THRESHOLD_CONDITION_CODE("$(V) >= 0.0"); SET_PARAM_NAMES({"gNa", "ENa", "gK", "EK", "gl", "El", "C"}); - SET_NEURON_VARS({{"V", "scalar"}, {"m", "scalar"}, {"h", "scalar"}, {"n", "scalar"}}); + SET_VARS({{"V", "scalar"}, {"m", "scalar"}, {"h", "scalar"}, {"n", "scalar"}}); }; //---------------------------------------------------------------------------- diff --git a/include/genn/genn/postsynapticModels.h b/include/genn/genn/postsynapticModels.h index abbac6d73b..78ed763d43 100644 --- 
a/include/genn/genn/postsynapticModels.h +++ b/include/genn/genn/postsynapticModels.h @@ -26,7 +26,7 @@ class GENN_EXPORT Base : public Models::Base // Declared virtuals //---------------------------------------------------------------------------- //! Gets model variables - virtual std::vector getVars() const{ return {}; } + virtual std::vector getVars() const{ return {}; } virtual std::string getDecayCode() const{ return ""; } virtual std::string getApplyInputCode() const{ return ""; } diff --git a/include/genn/genn/synapseGroupInternal.h b/include/genn/genn/synapseGroupInternal.h index 99e436bd90..8cf2a3e8a7 100644 --- a/include/genn/genn/synapseGroupInternal.h +++ b/include/genn/genn/synapseGroupInternal.h @@ -110,7 +110,7 @@ class SynapsePSMVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_SG.getPSVarLocation(varName); } - std::vector getDefs() const{ return m_SG.getPSModel()->getVars(); } + std::vector getDefs() const{ return m_SG.getPSModel()->getVars(); } const std::unordered_map &getInitialisers() const{ return m_SG.getPSVarInitialisers(); } @@ -118,7 +118,7 @@ class SynapsePSMVarAdapter bool isVarDelayed(const std::string &) const { return false; } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -164,13 +164,13 @@ class SynapseWUVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_SG.getWUVarLocation(varName); } - std::vector getDefs() const{ return m_SG.getWUModel()->getVars(); } + std::vector getDefs() const{ return m_SG.getWUModel()->getVars(); } const std::unordered_map &getInitialisers() const{ return m_SG.getWUVarInitialisers(); } const std::string &getNameSuffix() const{ return m_SG.getName(); } - VarAccessDim getVarDims(const Models::Base::SynapseVar &var) const{ return getVarAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -193,7 +193,7 @@ class SynapseWUPreVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_SG.getWUPreVarLocation(varName); } - std::vector getDefs() const{ return m_SG.getWUModel()->getPreVars(); } + std::vector getDefs() const{ return m_SG.getWUModel()->getPreVars(); } const std::unordered_map &getInitialisers() const{ return m_SG.getWUPreVarInitialisers(); } @@ -201,7 +201,7 @@ class SynapseWUPreVarAdapter bool isVarDelayed(const std::string&) const{ return (m_SG.getDelaySteps() != 0); } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- @@ -224,7 +224,7 @@ class SynapseWUPostVarAdapter //---------------------------------------------------------------------------- VarLocation getLoc(const std::string &varName) const{ return m_SG.getWUPostVarLocation(varName); } - std::vector getDefs() const{ return 
m_SG.getWUModel()->getPostVars(); } + std::vector getDefs() const{ return m_SG.getWUModel()->getPostVars(); } const std::unordered_map &getInitialisers() const{ return m_SG.getWUPostVarInitialisers(); } @@ -232,7 +232,7 @@ class SynapseWUPostVarAdapter bool isVarDelayed(const std::string&) const{ return (m_SG.getBackPropDelaySteps() != 0); } - VarAccessDim getVarDims(const Models::Base::NeuronVar &var) const{ return getVarAccessDim(var.access); } + VarAccessDim getVarDims(const Models::Base::Var &var) const{ return getVarAccessDim(var.access); } private: //---------------------------------------------------------------------------- diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h index a8c47262bc..b90dab8494 100644 --- a/include/genn/genn/varAccess.h +++ b/include/genn/genn/varAccess.h @@ -35,40 +35,19 @@ enum class VarAccessMode : unsigned int //! Flags defining dimensions this variables has enum class VarAccessDim : unsigned int { - NEURON = (1 << 5), - PRE_NEURON = (1 << 6), - POST_NEURON = (1 << 7), - BATCH = (1 << 8), + ELEMENT = (1 << 5), + BATCH = (1 << 6), }; -//! Supported combinations of access mode and dimension for neuron variables -enum class NeuronVarAccess : unsigned int +//! Supported combinations of access mode and dimension for neuron and synapse variables +enum class VarAccess : unsigned int { - READ_WRITE = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::NEURON) | static_cast(VarAccessDim::BATCH), - READ_ONLY = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::NEURON), - READ_ONLY_DUPLICATE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::NEURON) | static_cast(VarAccessDim::BATCH), + READ_WRITE = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::ELEMENT) | static_cast(VarAccessDim::BATCH), + READ_ONLY = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::ELEMENT), + READ_ONLY_DUPLICATE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::ELEMENT) | static_cast(VarAccessDim::BATCH), READ_ONLY_SHARED_NEURON = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::BATCH), }; -//! 
Supported combinations of access mode and dimension for synapse variables -enum class SynapseVarAccess : unsigned int -{ - // Synaptic variables - READ_WRITE = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::PRE_NEURON) | static_cast(VarAccessDim::POST_NEURON) | static_cast(VarAccessDim::BATCH), - READ_ONLY = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::PRE_NEURON) | static_cast(VarAccessDim::POST_NEURON), - READ_ONLY_DUPLICATE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::PRE_NEURON) | static_cast(VarAccessDim::POST_NEURON) | static_cast(VarAccessDim::BATCH), - - // Presynaptic variables - //READ_WRITE_PRE = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::PRE_NEURON) | static_cast(VarAccessDim::BATCH), - //READ_ONLY_PRE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::PRE_NEURON), - //READ_ONLY_PRE_DUPLICATE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::PRE_NEURON) | static_cast(VarAccessDim::BATCH), - - // Postsynaptic variables - //READ_WRITE_POST = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::POST_NEURON) | static_cast(VarAccessDim::BATCH), - //READ_ONLY_POST = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::POST_NEURON), - //READ_ONLY_POST_DUPLICATE = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::POST_NEURON) | static_cast(VarAccessDim::BATCH), -}; - //! Supported combinations of access mode and dimension for custom update variables /*! The axes are defined 'subtractively' ie VarAccessDim::BATCH indicates that this axis should be removed */ enum class CustomUpdateVarAccess : unsigned int @@ -78,18 +57,16 @@ enum class CustomUpdateVarAccess : unsigned int READ_ONLY = static_cast(VarAccessMode::READ_ONLY), // Variables which will be shared across batches if custom update is batched - READ_WRITE_SHARED = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::BATCH), READ_ONLY_SHARED = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::BATCH), - // Variables which will be shared across neurons if per-neuron - READ_WRITE_SHARED_NEURON = static_cast(VarAccessMode::READ_WRITE) | static_cast(VarAccessDim::NEURON), - READ_ONLY_SHARED_NEURON = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::NEURON), + // Variables which will be shared across neurons if per-element + READ_ONLY_SHARED_ELEMENT = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::ELEMENT), // Reduction variables REDUCE_BATCH_SUM = static_cast(VarAccessMode::REDUCE_SUM) | static_cast(VarAccessDim::BATCH), REDUCE_BATCH_MAX = static_cast(VarAccessMode::REDUCE_MAX) | static_cast(VarAccessDim::BATCH), - REDUCE_NEURON_SUM = static_cast(VarAccessMode::REDUCE_SUM) | static_cast(VarAccessDim::NEURON), - REDUCE_NEURON_MAX = static_cast(VarAccessMode::REDUCE_MAX) | static_cast(VarAccessDim::NEURON), + REDUCE_ELEMENT_SUM = static_cast(VarAccessMode::REDUCE_SUM) | static_cast(VarAccessDim::ELEMENT), + REDUCE_ELEMENT_MAX = static_cast(VarAccessMode::REDUCE_MAX) | static_cast(VarAccessDim::ELEMENT), }; //---------------------------------------------------------------------------- @@ -100,12 +77,7 @@ inline bool operator & (VarAccessMode mode, VarAccessModeAttribute modeAttribute return (static_cast(mode) & static_cast(modeAttribute)) != 0; } -inline bool operator & (NeuronVarAccess mode, VarAccessModeAttribute modeAttribute) -{ - return (static_cast(mode) & static_cast(modeAttribute)) != 0; -} - -inline bool 
operator & (SynapseVarAccess mode, VarAccessModeAttribute modeAttribute) +inline bool operator & (VarAccess mode, VarAccessModeAttribute modeAttribute) { return (static_cast(mode) & static_cast(modeAttribute)) != 0; } @@ -133,12 +105,7 @@ inline VarAccessDim clearVarAccessDim(VarAccessDim a, VarAccessDim b) return static_cast(static_cast(a) & ~static_cast(b)); } -inline VarAccessDim getVarAccessDim(NeuronVarAccess v) -{ - return static_cast(static_cast(v) & ~0x1F); -} - -inline VarAccessDim getVarAccessDim(SynapseVarAccess v) +inline VarAccessDim getVarAccessDim(VarAccess v) { return static_cast(static_cast(v) & ~0x1F); } @@ -153,12 +120,7 @@ inline VarAccessMode getVarAccessMode(VarAccessMode v) return v; } -inline VarAccessMode getVarAccessMode(NeuronVarAccess v) -{ - return static_cast(static_cast(v) & 0x1F); -} - -inline VarAccessMode getVarAccessMode(SynapseVarAccess v) +inline VarAccessMode getVarAccessMode(VarAccess v) { return static_cast(static_cast(v) & 0x1F); } diff --git a/include/genn/genn/weightUpdateModels.h b/include/genn/genn/weightUpdateModels.h index 3c2538ab91..1ed035067c 100644 --- a/include/genn/genn/weightUpdateModels.h +++ b/include/genn/genn/weightUpdateModels.h @@ -17,8 +17,8 @@ #define SET_PRE_DYNAMICS_CODE(PRE_DYNAMICS_CODE) virtual std::string getPreDynamicsCode() const override{ return PRE_DYNAMICS_CODE; } #define SET_POST_DYNAMICS_CODE(POST_DYNAMICS_CODE) virtual std::string getPostDynamicsCode() const override{ return POST_DYNAMICS_CODE; } -#define SET_PRE_VARS(...) virtual std::vector getPreVars() const override{ return __VA_ARGS__; } -#define SET_POST_VARS(...) virtual std::vector getPostVars() const override{ return __VA_ARGS__; } +#define SET_PRE_VARS(...) virtual std::vector getPreVars() const override{ return __VA_ARGS__; } +#define SET_POST_VARS(...) virtual std::vector getPostVars() const override{ return __VA_ARGS__; } //---------------------------------------------------------------------------- // GeNN::WeightUpdateModels::Base @@ -72,15 +72,15 @@ class GENN_EXPORT Base : public Models::Base virtual std::string getPostDynamicsCode() const{ return ""; } //! Gets model variables - virtual std::vector getVars() const{ return {}; } + virtual std::vector getVars() const{ return {}; } //! Gets names and types (as strings) of state variables that are common //! across all synapses coming from the same presynaptic neuron - virtual std::vector getPreVars() const{ return {}; } + virtual std::vector getPreVars() const{ return {}; } //! Gets names and types (as strings) of state variables that are common //! 
across all synapses going to the same postsynaptic neuron - virtual std::vector getPostVars() const{ return {}; } + virtual std::vector getPostVars() const{ return {}; } //------------------------------------------------------------------------ // Public methods @@ -140,7 +140,7 @@ class StaticPulse : public Base public: DECLARE_SNIPPET(StaticPulse); - SET_SYNAPSE_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}}); + SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}}); SET_SIM_CODE("addToPost(g);\n"); }; @@ -191,7 +191,7 @@ class StaticPulseDendriticDelay : public Base public: DECLARE_SNIPPET(StaticPulseDendriticDelay); - SET_SYNAPSE_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}, {"d", "uint8_t", SynapseVarAccess::READ_ONLY}}); + SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}, {"d", "uint8_t", VarAccess::READ_ONLY}}); SET_SIM_CODE("addToPostDelay(g, d);\n"); }; @@ -228,7 +228,7 @@ class StaticGraded : public Base DECLARE_SNIPPET(StaticGraded); SET_PARAM_NAMES({"Epre", "Vslope"}); - SET_SYNAPSE_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}}); + SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}}); SET_EVENT_CODE("addToPost(fmax(0.0, g * tanh((V_pre - Epre) / Vslope) * DT));\n"); @@ -299,7 +299,7 @@ class PiecewiseSTDP : public Base SET_PARAM_NAMES({"tLrn", "tChng", "tDecay", "tPunish10", "tPunish01", "gMax", "gMid", "gSlope", "tauShift", "gSyn0"}); - SET_SYNAPSE_VARS({{"g", "scalar"}, {"gRaw", "scalar"}}); + SET_VARS({{"g", "scalar"}, {"gRaw", "scalar"}}); SET_SIM_CODE( "addToPost(g);\n" diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 5a225c89b8..34d163699b 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -583,7 +583,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back // Loop through group members EnvironmentGroupMergedField memberEnv(groupEnv, c); - if (c.getArchetype().getDims() & VarAccessDim::NEURON) { + if (c.getArchetype().getDims() & VarAccessDim::ELEMENT) { memberEnv.print("for(unsigned int i = 0; i < $(size); i++)"); memberEnv.add(Type::Uint32.addConst(), "id", "i"); } @@ -611,7 +611,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back else { // Loop through group members EnvironmentGroupMergedField memberEnv(groupEnv, c); - if (c.getArchetype().getDims() & VarAccessDim::NEURON) { + if (c.getArchetype().getDims() & VarAccessDim::ELEMENT) { memberEnv.print("for(unsigned int i = 0; i < $(size); i++)"); memberEnv.add(Type::Uint32.addConst(), "id", "i"); } diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 3be1e8db6e..2143fb1452 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -701,7 +701,7 @@ std::string BackendBase::getReductionOperation(const std::string &reduction, con std::vector BackendBase::genInitReductionTargets(CodeStream &os, const CustomUpdateGroupMerged &cg, unsigned int batchSize, const std::string &idx) const { - return genInitReductionTargets( + return genInitReductionTargets( os, cg, batchSize, idx, [batchSize, &cg](const Models::VarReference &varRef, const std::string &index) { @@ -713,7 +713,7 @@ std::vector BackendBase::genInitReductionTargets(C std::vector BackendBase::genInitReductionTargets(CodeStream &os, const CustomUpdateWUGroupMerged &cg, unsigned int batchSize, const std::string &idx) const { - return genInitReductionTargets( + 
return genInitReductionTargets( os, cg, batchSize, idx, [batchSize, &cg](const Models::WUVarReference &varRef, const std::string &index) { diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index cb06968be0..5f08486742 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -210,7 +210,7 @@ size_t BackendSIMT::getPaddedNumCustomUpdateThreads(const CustomUpdateInternal & if (cg.isNeuronReduction()) { return padKernelSize(32 * numCopies, KernelCustomUpdate); } - else if (cg.getDims() & VarAccessDim::NEURON) { + else if (cg.getDims() & VarAccessDim::ELEMENT) { return numCopies * padKernelSize(cg.getSize(), KernelCustomUpdate); } else { @@ -514,7 +514,7 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM // Add population RNG field groupEnv.addField(getPopulationRNGType().createPointer(), "_rng", "rng", [this](const auto &g, size_t) { return getDeviceVarPrefix() + "rng" + g.getName(); }, - ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)")); + ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(id)")); // **TODO** for OCL do genPopulationRNGPreamble(os, popSubs, "group->rng[" + ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)") + "]") in initialiser ng.generateNeuronUpdate(*this, groupEnv, batchSize, @@ -588,10 +588,10 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM // Copy spikes into block of $(_spk) const std::string queueOffset = ng.getWriteVarIndex(ng.getArchetype().isDelayRequired(), batchSize, - VarAccessDim::BATCH | VarAccessDim::NEURON, ""); + VarAccessDim::BATCH | VarAccessDim::ELEMENT, ""); if(!Utils::areTokensEmpty(ng.getArchetype().getThresholdConditionCodeTokens())) { const std::string queueOffsetTrueSpk = ng.getWriteVarIndex(ng.getArchetype().isTrueSpikeRequired() && ng.getArchetype().isDelayRequired(), - batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, ""); + batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, ""); groupEnv.print("if(" + getThreadID() + " < $(_sh_spk_count))"); { CodeStream::Scope b(groupEnv.getStream()); @@ -806,7 +806,7 @@ void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, Mode { CodeStream::Scope b(groupEnv.getStream()); const std::string index = "(r * " + std::to_string(getKernelBlockSize(KernelPostsynapticUpdate)) + ") + " + getThreadID(); - groupEnv.printLine("const unsigned int spk = $(_trg_spk)[" + sg.getPostVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, index) + "];"); + groupEnv.printLine("const unsigned int spk = $(_trg_spk)[" + sg.getPostVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, index) + "];"); groupEnv.getStream() << "shSpk[" << getThreadID() << "] = spk;" << std::endl; if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { @@ -1042,8 +1042,8 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge } } } - // Otherwise, if this update is per-neuron - else if (cg.getArchetype().getDims() & VarAccessDim::NEURON) { + // Otherwise, if this update is per-element + else if (cg.getArchetype().getDims() & VarAccessDim::ELEMENT) { if((cg.getArchetype().getDims() & VarAccessDim::BATCH) && (batchSize > 1)) { // Split ID into intra-batch ID and batch // **TODO** fast-divide style optimisations here diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc 
b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index fcc0af2614..71cf73b381 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -85,7 +85,7 @@ std::string CustomUpdateGroupMerged::getVarIndex(unsigned int batchSize, VarAcce { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); - if (!(varDims & VarAccessDim::NEURON)) { + if (!(varDims & VarAccessDim::ELEMENT)) { return batched ? "$(batch)" : "0"; } else if (batched) { @@ -103,7 +103,7 @@ std::string CustomUpdateGroupMerged::getVarRefIndex(bool delay, unsigned int bat // If delayed, variable is shared, the batch size is one or this custom update isn't batched, batch delay offset isn't required if(delay) { const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); - if (!(varDims & VarAccessDim::NEURON)) { + if (!(varDims & VarAccessDim::ELEMENT)) { return batched ? "$(_batch_delay_slot)" : "$(_delay_slot)"; } else if (batched) { diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 453b5e5903..d8d1f8972d 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -34,7 +34,7 @@ size_t getNumVarCopies(VarAccessDim varDims, size_t batchSize, bool batched = tr //-------------------------------------------------------------------------- size_t getNumNeuronVarElements(VarAccessDim varDims, size_t numNeurons) { - return (varDims & VarAccessDim::NEURON) ? numNeurons : 1; + return (varDims & VarAccessDim::ELEMENT) ? numNeurons : 1; } //-------------------------------------------------------------------------- size_t getNeuronVarSize(VarAccessDim varDims, size_t numElements, size_t batchSize, @@ -46,26 +46,15 @@ size_t getNeuronVarSize(VarAccessDim varDims, size_t numElements, size_t batchSi size_t getSynapseVarSize(VarAccessDim varDims, const BackendBase &backend, const SynapseGroupInternal &sg, size_t batchSize, bool batched = true) { - const bool pre = (varDims & VarAccessDim::PRE_NEURON); - const bool post = (varDims & VarAccessDim::POST_NEURON); - const unsigned int numPre = sg.getSrcNeuronGroup()->getNumNeurons(); - const unsigned int numPost = sg.getTrgNeuronGroup()->getNumNeurons(); - const unsigned int rowStride = backend.getSynapticMatrixRowStride(sg); const size_t numCopies = getNumVarCopies(varDims, batchSize, batched); - if(pre && post) { + if(varDims & VarAccessDim::ELEMENT) { if(sg.getMatrixType() & SynapseMatrixWeight::KERNEL) { return sg.getKernelSizeFlattened() * numCopies; } else { - return numPre * rowStride * numCopies; + return sg.getSrcNeuronGroup()->getNumNeurons() * backend.getSynapticMatrixRowStride(sg) * numCopies; } } - else if(pre) { - return numPre * numCopies; - } - else if(post) { - return numPost * numCopies; - } else { return numCopies; } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 9e4e231f00..f0ccc5be45 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -88,7 +88,7 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e // If variable has NEURON axis const VarAccessDim varDims = adaptor.getVarDims(var); - if (varDims & VarAccessDim::NEURON) { + if (varDims & VarAccessDim::ELEMENT) { backend.genVariableInit( varEnv, count, 
"id", [&adaptor, &fieldGroup, &fieldSuffix, &group, &var, &resolvedType, &varInit, batchSize, count, numDelaySlots, varDims] @@ -222,7 +222,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir [batchSize, this] (EnvironmentExternalBase &varEnv) { genVariableFill(varEnv, "_out_post", Type::writeNumeric(0.0, getScalarType()), - "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::NEURON, batchSize); + "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::ELEMENT, batchSize); }); @@ -235,7 +235,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir [batchSize, this](EnvironmentExternalBase &varEnv) { genVariableFill(varEnv, "_den_delay", Type::writeNumeric(0.0, getScalarType()), - "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::NEURON, + "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::ELEMENT, batchSize, true, getArchetype().getMaxDendriticDelayTimesteps()); }); @@ -269,7 +269,7 @@ void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend [batchSize, this] (EnvironmentExternalBase &varEnv) { genVariableFill(varEnv, "_out_pre", Type::writeNumeric(0.0, getScalarType()), - "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::NEURON, batchSize); + "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::ELEMENT, batchSize); }); } @@ -450,7 +450,7 @@ void NeuronInitGroupMerged::genInitSpikeCount(const BackendBase &backend, Enviro (getArchetype().isTrueSpikeRequired() && getArchetype().isDelayRequired()); // Zero across all delay slots and batches - genScalarFill(spikeCountEnv, "_spk_cnt", "0", VarAccessDim::BATCH | VarAccessDim::NEURON, + genScalarFill(spikeCountEnv, "_spk_cnt", "0", VarAccessDim::BATCH | VarAccessDim::ELEMENT, batchSize, delayRequired, getArchetype().getNumDelaySlots()); }); } @@ -481,7 +481,7 @@ void NeuronInitGroupMerged::genInitSpikes(const BackendBase &backend, Environmen (getArchetype().isTrueSpikeRequired() && getArchetype().isDelayRequired()); // Zero across all delay slots and batches - genVariableFill(varEnv, "_spk", "0", "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::NEURON, + genVariableFill(varEnv, "_spk", "0", "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::ELEMENT, batchSize, delayRequired, getArchetype().getNumDelaySlots()); }); } @@ -500,7 +500,7 @@ void NeuronInitGroupMerged::genInitSpikeTime(const BackendBase &backend, Environ backend.genVariableInit(env, "num_neurons", "id", [batchSize, varName, this] (EnvironmentExternalBase &varEnv) { - genVariableFill(varEnv, varName, "-TIME_MAX", "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::NEURON, + genVariableFill(varEnv, varName, "-TIME_MAX", "id", "$(num_neurons)", VarAccessDim::BATCH | VarAccessDim::ELEMENT, batchSize, getArchetype().isDelayRequired(), getArchetype().getNumDelaySlots()); }); } diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index f108f4bee5..0c810573cf 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -40,7 +40,7 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, &ng](const std::string&, NeuronVarAccess d) + [batchSize, 
&ng](const std::string&, VarAccess d) { return ng.getVarIndex(batchSize, getVarAccessDim(d), "$(id)"); }); @@ -83,7 +83,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); // Read into local variable - const std::string idx = ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)"); + const std::string idx = ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(id)"); psmEnv.getStream() << "// postsynaptic model " << getIndex() << std::endl; psmEnv.printLine(getScalarType().getName() + " linSyn = $(_out_post)[" + idx + "];"); @@ -121,7 +121,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, &ng](const std::string&, NeuronVarAccess d) + [batchSize, &ng](const std::string&, VarAccess d) { return ng.getVarIndex(batchSize, getVarAccessDim(d), "$(id)"); }); @@ -134,7 +134,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env prettyPrintStatements(getArchetype().getPSDecayCodeTokens(), getTypeContext(), varEnv, decayErrorHandler); // Write back linSyn - varEnv.printLine("$(_out_post)[" + ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)") + "] = linSyn;"); + varEnv.printLine("$(_out_post)[" + ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(id)") + "] = linSyn;"); } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &hash) const @@ -168,7 +168,7 @@ void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backe [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); // Add reverse insyn variable to - const std::string idx = ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)"); + const std::string idx = ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(id)"); outSynEnv.printLine(getArchetype().getPreTargetVar() + " += $(_out_pre)[" + idx + "];"); // Zero it again @@ -202,15 +202,15 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, delayed, &synEnv, &ng](const std::string&, NeuronVarAccess d) + [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccess d) { return ng.getReadVarIndex(delayed, batchSize, getVarAccessDim(d), "$(id)"); }, - [batchSize, delayed, &synEnv, &ng](const std::string&, NeuronVarAccess d) + [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccess d) { return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDim(d), "$(id)"); }, - [delayed](const std::string&, NeuronVarAccess) + [delayed](const std::string&, VarAccess) { return delayed; }); @@ -294,15 +294,15 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back const bool delayed = (getArchetype().getDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), synEnv, 
backend.getDeviceVarPrefix(), fieldSuffix, "l", - [batchSize, delayed, &ng](const std::string&, NeuronVarAccess d) + [batchSize, delayed, &ng](const std::string&, VarAccess d) { return ng.getReadVarIndex(delayed, batchSize, getVarAccessDim(d), "$(id)"); }, - [batchSize, delayed, &ng](const std::string&, NeuronVarAccess d) + [batchSize, delayed, &ng](const std::string&, VarAccess d) { return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDim(d), "$(id)"); }, - [delayed](const std::string&, NeuronVarAccess) + [delayed](const std::string&, VarAccess) { return delayed; }); @@ -497,7 +497,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Substitute spike times const std::string timePrecision = getTimeType().getName(); const std::string spikeTimeReadIndex = getReadVarIndex(getArchetype().isDelayRequired(), batchSize, - VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)"); + VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(id)"); neuronEnv.add(getTimeType().addConst(), "st", "lsT", {neuronEnv.addInitialiser("const " + timePrecision + " lsT = $(_st)[" + spikeTimeReadIndex + "];")}); neuronEnv.add(getTimeType().addConst(), "prev_st", "lprevST", @@ -512,17 +512,17 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // **NOTE** always copy variables if variable is delayed EnvironmentLocalVarCache neuronVarEnv( *this, *this, getTypeContext(), neuronEnv, backend.getDeviceVarPrefix(), "", "l", - [batchSize, &neuronEnv, this](const std::string &varName, NeuronVarAccess d) + [batchSize, &neuronEnv, this](const std::string &varName, VarAccess d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); return getReadVarIndex(delayed, batchSize, getVarAccessDim(d), "$(id)") ; }, - [batchSize, &neuronEnv, this](const std::string &varName, NeuronVarAccess d) + [batchSize, &neuronEnv, this](const std::string &varName, VarAccess d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); return getWriteVarIndex(delayed, batchSize, getVarAccessDim(d), "$(id)") ; }, - [this](const std::string &varName, NeuronVarAccess) + [this](const std::string &varName, VarAccess) { return (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); }); @@ -704,12 +704,12 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // If spike times are required, copy times from register if(getArchetype().isSpikeTimeRequired()) { - neuronVarEnv.printLine("$(_st)[" + getWriteVarIndex(true, batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)") + "] = $(st);"); + neuronVarEnv.printLine("$(_st)[" + getWriteVarIndex(true, batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(id)") + "] = $(st);"); } // If previous spike times are required, copy times from register if(getArchetype().isPrevSpikeTimeRequired()) { - neuronVarEnv.printLine("$(_prev_st)[" + getWriteVarIndex(true, batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id)") + "] = $(prev_st);"); + neuronVarEnv.printLine("$(_prev_st)[" + getWriteVarIndex(true, batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(id)") + "] = $(prev_st);"); } // Loop through outgoing synapse groups with some sort of presynaptic code @@ -746,7 +746,7 @@ std::string NeuronUpdateGroupMerged::getVarIndex(unsigned int batchSize, VarAcce { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? 
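// Variables without the ELEMENT dimension hold a single value (one per batch when
// batched), so the index collapses to "$(batch)" or "0"; variables with the ELEMENT
// dimension are indexed by the id passed in, with an additional batch offset applied
// when they also carry the BATCH dimension and the batch size is greater than one
// (that branch sits outside this hunk).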
const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); - if (!(varDims & VarAccessDim::NEURON)) { + if (!(varDims & VarAccessDim::ELEMENT)) { return batched ? "$(batch)" : "0"; } else if(batched) { @@ -762,7 +762,7 @@ std::string NeuronUpdateGroupMerged::getReadVarIndex(bool delay, unsigned int ba { if(delay) { const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); - if (!(varDims & VarAccessDim::NEURON)) { + if (!(varDims & VarAccessDim::ELEMENT)) { return batched ? "$(_read_batch_delay_slot)" : "$(_read_delay_slot)"; } else if(batched) { @@ -782,7 +782,7 @@ std::string NeuronUpdateGroupMerged::getWriteVarIndex(bool delay, unsigned int b { if(delay) { const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); - if (!(varDims & VarAccessDim::NEURON)) { + if (!(varDims & VarAccessDim::ELEMENT)) { return batched ? "$(_write_batch_delay_slot)" : "$(_write_delay_slot)"; } else if (batched) { diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index 6598e09847..c9fa7f90dc 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -96,7 +96,7 @@ void PreSpan::genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerg { CodeStream::Scope b(env.getStream()); - env.printLine("const unsigned int preInd = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "spike") + "];"); + env.printLine("const unsigned int preInd = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "spike") + "];"); const auto indexType = backend.getSynapseIndexType(sg); const auto indexTypeName = indexType.getName(); @@ -247,7 +247,7 @@ void PostSpan::genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMer { CodeStream::Scope b(env.getStream()); const std::string index = "(r * " + std::to_string(backend.getKernelBlockSize(KernelPresynapticUpdate)) + ") + " + backend.getThreadID(); - env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, index) + "];"); + env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, index) + "];"); env.printLine("$(_sh_spk" + eventSuffix + ")[" + backend.getThreadID() + "] = spk;"); if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { env.printLine("$(_sh_row_length)[" + backend.getThreadID() + "] = $(_row_length)[spk];"); @@ -459,7 +459,7 @@ void PreSpanProcedural::genUpdate(EnvironmentExternalBase &env, PresynapticUpdat // Create environment and add presynaptic index EnvironmentGroupMergedField synEnv(groupEnv, sg); synEnv.add(Type::Uint32.addConst(), "id_pre", "preInd", - {synEnv.addInitialiser("const unsigned int preInd = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(_spike)") + "];")}); + {synEnv.addInitialiser("const unsigned int preInd = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(_spike)") + "];")}); // **YUCK** add a hidden copy of num_post so we can overwrite deeper in here without losing access to original synEnv.add(Type::Uint32.addConst(), "_num_post", "$(num_post)"); @@ -639,7 +639,7 @@ void 
PostSpanBitmask::genUpdate(EnvironmentExternalBase &env, PresynapticUpdateG { CodeStream::Scope b(env.getStream()); const std::string index = "(r * " + std::to_string(backend.getKernelBlockSize(KernelPresynapticUpdate)) + ") + " + backend.getThreadID(); - env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, index) + "];"); + env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, index) + "];"); env.printLine("$(_sh_spk" + eventSuffix + ")[" + backend.getThreadID() + "] = spk;"); } backend.genSharedMemBarrier(env.getStream()); @@ -873,7 +873,7 @@ void PostSpanToeplitz::genUpdate(EnvironmentExternalBase &env, PresynapticUpdate { CodeStream::Scope b(env.getStream()); const std::string index = "(r * " + std::to_string(backend.getKernelBlockSize(KernelPresynapticUpdate)) + ") + " + backend.getThreadID(); - env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, index) + "];"); + env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, index) + "];"); env.printLine("$(_sh_spk" + eventSuffix + ")[" + backend.getThreadID() + "] = spk;"); } backend.genSharedMemBarrier(env.getStream()); diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index e70f102f0c..aafebb7e95 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -30,13 +30,13 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa // Substitute names of pre and postsynaptic weight update variable synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](NeuronVarAccess a, const std::string&) + [&sg, batchSize](VarAccess a, const std::string&) { return sg.getPreWUVarIndex(batchSize, getVarAccessDim(a), "$(id_pre)"); }, "", true); synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](NeuronVarAccess a, const std::string&) + [&sg, batchSize](VarAccess a, const std::string&) { return sg.getPostWUVarIndex(batchSize, getVarAccessDim(a), "$(id_post)"); }, "", true); @@ -53,8 +53,8 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa const std::string timeStr = sg.getTimeType().getName(); const std::string axonalDelayMs = Type::writeNumeric(dt * (double)(sg.getArchetype().getDelaySteps() + 1u), sg.getTimeType()); const bool preDelay = sg.getArchetype().getSrcNeuronGroup()->isDelayRequired(); - const std::string preSTIndex = sg.getPreVarIndex(preDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id_pre)"); - const std::string prevPreSTIndex = sg.getPrePrevSpikeTimeIndex(preDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id_pre)"); + const std::string preSTIndex = sg.getPreVarIndex(preDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(id_pre)"); + const std::string prevPreSTIndex = sg.getPrePrevSpikeTimeIndex(preDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(id_pre)"); synEnv.add(sg.getTimeType().addConst(), "st_pre", "stPre", {synEnv.addInitialiser("const " + timeStr + " stPre = " + axonalDelayMs + " + $(_src_st)[" + preSTIndex + "];")}); 
synEnv.add(sg.getTimeType().addConst(), "prev_st_pre", "prevSTPre", @@ -67,8 +67,8 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa // Calculate backprop delay to add to (somatic) spike times and substitute in postsynaptic spike times const std::string backPropDelayMs = Type::writeNumeric(dt * (double)(sg.getArchetype().getBackPropDelaySteps() + 1u), sg.getTimeType()); const bool postDelay = sg.getArchetype().getTrgNeuronGroup()->isDelayRequired(); - const std::string postSTIndex = sg.getPostVarIndex(postDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id_post)"); - const std::string prevPostSTIndex = sg.getPostPrevSpikeTimeIndex(postDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::NEURON, "$(id_post)"); + const std::string postSTIndex = sg.getPostVarIndex(postDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(id_post)"); + const std::string prevPostSTIndex = sg.getPostPrevSpikeTimeIndex(postDelay, batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(id_post)"); synEnv.add(sg.getTimeType().addConst(), "st_post", "stPost", {synEnv.addInitialiser("const " + timeStr + " stPost = " + backPropDelayMs + " + $(_trg_st)[" + postSTIndex + "];")}); synEnv.add(sg.getTimeType().addConst(), "prev_st_post", "prevSTPost", @@ -78,7 +78,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa if (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::INDIVIDUAL) { synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](SynapseVarAccess a, const std::string&) + [&sg, batchSize](VarAccess a, const std::string&) { return sg.getSynVarIndex(batchSize, getVarAccessDim(a), "$(id_syn)"); }); @@ -121,7 +121,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa synEnv.template addVars( backend.getDeviceVarPrefix(), - [&sg, batchSize](SynapseVarAccess a, const std::string&) + [&sg, batchSize](VarAccess a, const std::string&) { return sg.getKernelVarIndex(batchSize, getVarAccessDim(a), "$(id_kernel)"); }); @@ -252,7 +252,7 @@ std::string SynapseGroupMergedBase::getPrePostVarIndex(bool delay, unsigned int { const bool batched = ((varDims & VarAccessDim::BATCH) && batchSize > 1); if (delay) { - if (!(varDims & VarAccessDim::NEURON)) { + if (!(varDims & VarAccessDim::ELEMENT)) { return (batched ? "$(_" + prefix + "_batch_delay_slot)" : "$(_" + prefix + "_delay_slot)"); } else if(batched) { @@ -263,7 +263,7 @@ std::string SynapseGroupMergedBase::getPrePostVarIndex(bool delay, unsigned int } } else { - if (!(varDims & VarAccessDim::NEURON)) { + if (!(varDims & VarAccessDim::ELEMENT)) { return batched ? 
"$(batch)" : "0"; } else if (batched) { diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 6b20820c3a..c8ea038aeb 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -456,14 +456,7 @@ EGPReference EGPReference::createWUEGPRef(const SynapseGroup *sg, const std::str //---------------------------------------------------------------------------- // Free functions //---------------------------------------------------------------------------- -void updateHash(const Base::NeuronVar &v, boost::uuids::detail::sha1 &hash) -{ - Utils::updateHash(v.name, hash); - Type::updateHash(v.type, hash); - Utils::updateHash(v.access, hash); -} -//---------------------------------------------------------------------------- -void updateHash(const Base::SynapseVar &v, boost::uuids::detail::sha1 &hash) +void updateHash(const Base::Var &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.name, hash); Type::updateHash(v.type, hash); diff --git a/tests/unit/customConnectivityUpdate.cc b/tests/unit/customConnectivityUpdate.cc index ac65ba2fe6..ecb8296f43 100644 --- a/tests/unit/customConnectivityUpdate.cc +++ b/tests/unit/customConnectivityUpdate.cc @@ -19,7 +19,7 @@ class StaticPulseDendriticDelayReverse : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulseDendriticDelayReverse); - SET_SYNAPSE_VARS({{"d", "uint8_t", SynapseVarAccess::READ_ONLY}, {"g", "scalar", SynapseVarAccess::READ_ONLY}}); + SET_VARS({{"d", "uint8_t", VarAccess::READ_ONLY}, {"g", "scalar", VarAccess::READ_ONLY}}); SET_SIM_CODE("addToPostDelay(g, d);\n"); }; @@ -41,7 +41,7 @@ class RemoveSynapse : public CustomConnectivityUpdateModels::Base public: DECLARE_SNIPPET(RemoveSynapse); - SET_SYNAPSE_VARS({{"a", "scalar"}}); + SET_VARS({{"a", "scalar"}}); SET_ROW_UPDATE_CODE( "for_each_synapse {\n" " if(id_post == (id_pre + 1)) {\n" @@ -57,7 +57,7 @@ class RemoveSynapseVarRef : public CustomConnectivityUpdateModels::Base public: DECLARE_SNIPPET(RemoveSynapseVarRef); - SET_SYNAPSE_VARS({{"a", "scalar"}}); + SET_VARS({{"a", "scalar"}}); SET_VAR_REFS({{"b", "scalar"}}); SET_ROW_UPDATE_CODE( "for_each_synapse {\n" @@ -108,7 +108,7 @@ class Cont : public WeightUpdateModels::Base public: DECLARE_SNIPPET(Cont); - SET_SYNAPSE_VARS({{"g", "scalar"}}); + SET_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( "addToPost(g * V_pre);\n"); @@ -120,7 +120,7 @@ class ContPost : public WeightUpdateModels::Base public: DECLARE_SNIPPET(ContPost); - SET_SYNAPSE_VARS({{"g", "scalar"}}); + SET_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( "addToPost(g * V_post);\n"); diff --git a/tests/unit/customUpdate.cc b/tests/unit/customUpdate.cc index 3410f08d82..0f1f4bb2ed 100644 --- a/tests/unit/customUpdate.cc +++ b/tests/unit/customUpdate.cc @@ -23,9 +23,9 @@ class IzhikevichVariableShared : public NeuronModels::Izhikevich DECLARE_SNIPPET(IzhikevichVariableShared); SET_PARAM_NAMES({}); - SET_NEURON_VARS({{"V","scalar"}, {"U", "scalar"}, - {"a", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}, {"b", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}, - {"c", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}, {"d", "scalar", NeuronVarAccess::READ_ONLY_SHARED_NEURON}}); + SET_VARS({{"V","scalar"}, {"U", "scalar"}, + {"a", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}, {"b", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}, + {"c", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}, {"d", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}}); }; IMPLEMENT_SNIPPET(IzhikevichVariableShared); @@ -34,10 +34,10 @@ class 
StaticPulseDendriticDelaySplit : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulseDendriticDelaySplit); - SET_SYNAPSE_VARS({{"gCommon", "scalar", SynapseVarAccess::READ_ONLY}, - {"g", "scalar", SynapseVarAccess::READ_ONLY_DUPLICATE}, - {"dCommon", "scalar", SynapseVarAccess::READ_ONLY}, - {"d", "scalar", SynapseVarAccess::READ_ONLY_DUPLICATE}}); + SET_VARS({{"gCommon", "scalar", VarAccess::READ_ONLY}, + {"g", "scalar", VarAccess::READ_ONLY_DUPLICATE}, + {"dCommon", "scalar", VarAccess::READ_ONLY}, + {"d", "scalar", VarAccess::READ_ONLY_DUPLICATE}}); SET_SIM_CODE("$(addToInSynDelay, $(gCommon) + $(g), $(dCommon) + $(d));\n"); }; @@ -73,7 +73,7 @@ class Sum3 : public CustomUpdateModels::Base SET_UPDATE_CODE("$(sum) = $(scale) * ($(a) + $(b));\n"); - SET_CUSTOM_UPDATE_VARS({{"sum", "scalar"}, {"scale", "scalar", CustomUpdateVarAccess::READ_ONLY_SHARED_NEURON}}); + SET_CUSTOM_UPDATE_VARS({{"sum", "scalar"}, {"scale", "scalar", CustomUpdateVarAccess::READ_ONLY_SHARED_ELEMENT}}); SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_WRITE}, {"b", "scalar", VarAccessMode::READ_ONLY}}); }; @@ -131,7 +131,7 @@ class Cont : public WeightUpdateModels::Base public: DECLARE_SNIPPET(Cont); - SET_SYNAPSE_VARS({{"g", "scalar"}}); + SET_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( "addToPost(g * V_pre);\n"); @@ -143,7 +143,7 @@ class Cont2 : public WeightUpdateModels::Base public: DECLARE_SNIPPET(Cont2); - SET_SYNAPSE_VARS({{"g", "scalar"}, {"x", "scalar"}}); + SET_VARS({{"g", "scalar"}, {"x", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( "addToPost((g + x) * V_pre);\n"); @@ -170,7 +170,7 @@ class ReduceDouble : public CustomUpdateModels::Base "reduction2 = var2;\n"); SET_CUSTOM_UPDATE_VARS({{"reduction1", "scalar", CustomUpdateVarAccess::REDUCE_BATCH_SUM}, - {"reduction2", "scalar", CustomUpdateVarAccess::REDUCE_NEURON_SUM}}); + {"reduction2", "scalar", CustomUpdateVarAccess::REDUCE_ELEMENT_SUM}}); SET_VAR_REFS({{"var1", "scalar", VarAccessMode::READ_ONLY}, {"var2", "scalar", VarAccessMode::READ_ONLY}}); @@ -195,7 +195,7 @@ class ReduceNeuronSharedVar : public CustomUpdateModels::Base SET_UPDATE_CODE("reduction = var;\n"); - SET_CUSTOM_UPDATE_VARS({{"reduction", "scalar", CustomUpdateVarAccess::REDUCE_NEURON_SUM}}) + SET_CUSTOM_UPDATE_VARS({{"reduction", "scalar", CustomUpdateVarAccess::REDUCE_ELEMENT_SUM}}) SET_VAR_REFS({{"var", "scalar", VarAccessMode::READ_ONLY}}); }; IMPLEMENT_SNIPPET(ReduceNeuronSharedVar); @@ -510,13 +510,13 @@ TEST(CustomUpdates, BatchingVars) model.finalise(); EXPECT_TRUE(static_cast(sum1)->getDims() & VarAccessDim::BATCH); - EXPECT_TRUE(static_cast(sum1)->getDims() & VarAccessDim::NEURON); + EXPECT_TRUE(static_cast(sum1)->getDims() & VarAccessDim::ELEMENT); EXPECT_FALSE(static_cast(sum2)->getDims() & VarAccessDim::BATCH); - EXPECT_TRUE(static_cast(sum2)->getDims() & VarAccessDim::NEURON); + EXPECT_TRUE(static_cast(sum2)->getDims() & VarAccessDim::ELEMENT); EXPECT_TRUE(static_cast(sum3)->getDims() & VarAccessDim::BATCH); - EXPECT_TRUE(static_cast(sum3)->getDims() & VarAccessDim::NEURON); + EXPECT_TRUE(static_cast(sum3)->getDims() & VarAccessDim::ELEMENT); EXPECT_FALSE(static_cast(sum4)->getDims() & VarAccessDim::BATCH); - EXPECT_TRUE(static_cast(sum4)->getDims() & VarAccessDim::NEURON); + EXPECT_TRUE(static_cast(sum4)->getDims() & VarAccessDim::ELEMENT); } //-------------------------------------------------------------------------- TEST(CustomUpdates, NeuronSharedVars) @@ -535,7 +535,7 @@ TEST(CustomUpdates, NeuronSharedVars) auto *cuInternal = static_cast(cu); 
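// Sharing variables across neurons drops the ELEMENT axis, so this update is
// expected to retain the BATCH dimension only.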
EXPECT_TRUE(cuInternal->getDims() & VarAccessDim::BATCH); - EXPECT_FALSE(cuInternal->getDims() & VarAccessDim::NEURON); + EXPECT_FALSE(cuInternal->getDims() & VarAccessDim::ELEMENT); } //-------------------------------------------------------------------------- TEST(CustomUpdates, BatchingWriteShared) @@ -621,7 +621,7 @@ TEST(CustomUpdates, ReductionTypeDuplicateNeuron) ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::BATCH); ASSERT_FALSE(cuInternal->isBatchReduction()); ASSERT_TRUE(cuInternal->isNeuronReduction()); - ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::NEURON); + ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::ELEMENT); } //-------------------------------------------------------------------------- TEST(CustomUpdates, ReductionTypeDuplicateNeuronInternal) @@ -644,7 +644,7 @@ TEST(CustomUpdates, ReductionTypeDuplicateNeuronInternal) ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::BATCH); ASSERT_FALSE(cuInternal->isBatchReduction()); ASSERT_TRUE(cuInternal->isNeuronReduction()); - ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::NEURON); + ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::ELEMENT); } //-------------------------------------------------------------------------- TEST(CustomUpdates, ReductionTypeSharedNeuronInternal) @@ -667,7 +667,7 @@ TEST(CustomUpdates, ReductionTypeSharedNeuronInternal) ASSERT_FALSE(cuInternal->getDims() & VarAccessDim::BATCH); ASSERT_FALSE(cuInternal->isBatchReduction()); ASSERT_TRUE(cuInternal->isNeuronReduction()); - ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::NEURON); + ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::ELEMENT); } //-------------------------------------------------------------------------- TEST(CustomUpdates, ReductionTypeDuplicateBatch) @@ -689,7 +689,7 @@ TEST(CustomUpdates, ReductionTypeDuplicateBatch) ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::BATCH); ASSERT_TRUE(cuInternal->isBatchReduction()); ASSERT_FALSE(cuInternal->isNeuronReduction()); - ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::NEURON); + ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::ELEMENT); } //-------------------------------------------------------------------------- TEST(CustomUpdates, ReductionTypeDuplicateBatchInternal) @@ -712,7 +712,7 @@ TEST(CustomUpdates, ReductionTypeDuplicateBatchInternal) ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::BATCH); ASSERT_TRUE(cuInternal->isBatchReduction()); ASSERT_FALSE(cuInternal->isNeuronReduction()); - ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::NEURON); + ASSERT_TRUE(cuInternal->getDims() & VarAccessDim::ELEMENT); } //-------------------------------------------------------------------------- TEST(CustomUpdates, NeuronSharedCustomUpdateWU) diff --git a/tests/unit/modelSpec.cc b/tests/unit/modelSpec.cc index d89f242661..ef208fc6a3 100644 --- a/tests/unit/modelSpec.cc +++ b/tests/unit/modelSpec.cc @@ -24,7 +24,7 @@ class AlphaCurr : public PostsynapticModels::Base SET_PARAM_NAMES({"tau"}); - SET_NEURON_VARS({{"x", "scalar"}}); + SET_VARS({{"x", "scalar"}}); SET_DERIVED_PARAMS({ {"expDecay", [](const ParamValues &pars, double dt) { return std::exp(-dt / pars.at("tau")); }}, @@ -50,7 +50,7 @@ class RemoveSynapse : public CustomConnectivityUpdateModels::Base public: DECLARE_SNIPPET(RemoveSynapse); - SET_SYNAPSE_VARS({{"a", "scalar"}}); + SET_VARS({{"a", "scalar"}}); SET_ROW_UPDATE_CODE( "for_each_synapse{\n" " if(id_post == (id_pre + 1)) {\n" diff --git a/tests/unit/modelSpecMerged.cc b/tests/unit/modelSpecMerged.cc index 2f0651990b..d28b586e64 100644 --- 
a/tests/unit/modelSpecMerged.cc +++ b/tests/unit/modelSpecMerged.cc @@ -32,7 +32,7 @@ class AlphaCurr : public PostsynapticModels::Base SET_PARAM_NAMES({"tau"}); - SET_NEURON_VARS({{"x", "scalar"}}); + SET_VARS({{"x", "scalar"}}); SET_DERIVED_PARAMS({ {"expDecay", [](const ParamValues &pars, double dt) { return std::exp(-dt / pars.at("tau")); }}, @@ -53,7 +53,7 @@ class STDPAdditive : public WeightUpdateModels::Base "Wmin", // 4 - Minimum weight "Wmax"}); // 5 - Maximum weight - SET_SYNAPSE_VARS({{"g", "scalar"}}); + SET_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); @@ -127,7 +127,7 @@ class RemoveSynapsePrePost : public CustomConnectivityUpdateModels::Base public: DECLARE_SNIPPET(RemoveSynapsePrePost); - SET_SYNAPSE_VARS({{"g", "scalar"}}); + SET_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preThresh", "scalar"}}); SET_POST_VARS({{"postThresh", "scalar"}}); SET_ROW_UPDATE_CODE( diff --git a/tests/unit/models.cc b/tests/unit/models.cc index 972985166d..18335c61a2 100644 --- a/tests/unit/models.cc +++ b/tests/unit/models.cc @@ -24,7 +24,7 @@ class AlphaCurr : public PostsynapticModels::Base SET_PARAM_NAMES({"tau"}); - SET_NEURON_VARS({{"x", "scalar"}}); + SET_VARS({{"x", "scalar"}}); SET_DERIVED_PARAMS({ {"expDecay", [](const ParamValues &pars, double dt) { return std::exp(-dt / pars.at("tau")); }}, @@ -48,7 +48,7 @@ class Cont : public WeightUpdateModels::Base public: DECLARE_SNIPPET(Cont); - SET_SYNAPSE_VARS({{"g", "scalar"}}); + SET_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( "addToPost(g * V_pre);\n"); @@ -60,7 +60,7 @@ class ContPrePost : public WeightUpdateModels::Base public: DECLARE_SNIPPET(ContPrePost); - SET_SYNAPSE_VARS({{"g", "scalar"}}); + SET_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); diff --git a/tests/unit/neuronGroup.cc b/tests/unit/neuronGroup.cc index 68e8bc6bcc..dc9f4c2c9c 100644 --- a/tests/unit/neuronGroup.cc +++ b/tests/unit/neuronGroup.cc @@ -19,7 +19,7 @@ class StaticPulseBack : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulseBack); - SET_SYNAPSE_VARS({{"g", "scalar", SynapseVarAccess::READ_ONLY}}); + SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}}); SET_SIM_CODE( "$(addToInSyn, $(g));\n" @@ -79,7 +79,7 @@ class AlphaCurr : public PostsynapticModels::Base SET_PARAM_NAMES({"tau"}); - SET_NEURON_VARS({{"x", "scalar"}}); + SET_VARS({{"x", "scalar"}}); SET_DERIVED_PARAMS({ {"expDecay", [](const ParamValues &pars, double dt) { return std::exp(-dt / pars.at("tau")); }}, @@ -122,7 +122,7 @@ class LIFAdditional : public NeuronModels::Base {"ExpTC", [](const ParamValues &pars, double dt) { return std::exp(-dt / pars.at("TauM")); }}, {"Rmembrane", [](const ParamValues &pars, double) { return pars.at("TauM") / pars.at("C"); }}}); - SET_NEURON_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); + SET_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); }; IMPLEMENT_SNIPPET(LIFAdditional); @@ -163,7 +163,7 @@ class LIFRandom : public NeuronModels::Base {"ExpTC", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("TauM")); }}, {"Rmembrane", [](const ParamValues &pars, double){ return pars.at("TauM") / pars.at("C"); }}}); - SET_NEURON_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); + SET_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); }; IMPLEMENT_SNIPPET(LIFRandom); @@ -176,7 +176,7 @@ class STDPAdditive : public WeightUpdateModels::Base SET_DERIVED_PARAMS({ {"tauPlusDecay", [](const ParamValues &pars, double 
dt){ return std::exp(-dt / pars.at("tauPlus")); }}, {"tauMinusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauMinus")); }}}); - SET_SYNAPSE_VARS({{"g", "scalar"}}); + SET_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); diff --git a/tests/unit/neuronModels.cc b/tests/unit/neuronModels.cc index 30eb0ba3b1..fc2c244a23 100644 --- a/tests/unit/neuronModels.cc +++ b/tests/unit/neuronModels.cc @@ -42,7 +42,7 @@ class LIFCopy : public NeuronModels::Base {"ExpTC", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("TauM")); }}, {"Rmembrane", [](const ParamValues &pars, double){ return pars.at("TauM") / pars.at("C"); }}}); - SET_NEURON_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); + SET_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); SET_NEEDS_AUTO_REFRACTORY(false); }; diff --git a/tests/unit/synapseGroup.cc b/tests/unit/synapseGroup.cc index d45a1d1a1b..39bfac2c66 100644 --- a/tests/unit/synapseGroup.cc +++ b/tests/unit/synapseGroup.cc @@ -26,7 +26,7 @@ class STDPAdditive : public WeightUpdateModels::Base SET_DERIVED_PARAMS({ {"tauPlusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauPlus")); }}, {"tauMinusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauMinus")); }}}); - SET_SYNAPSE_VARS({{"g", "scalar"}}); + SET_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); @@ -58,7 +58,7 @@ class STDPAdditiveEGPWMinMax : public WeightUpdateModels::Base SET_DERIVED_PARAMS({ {"tauPlusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauPlus")); }}, {"tauMinusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauMinus")); }}}); - SET_SYNAPSE_VARS({{"g", "scalar"}}); + SET_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); SET_EXTRA_GLOBAL_PARAMS({{"Wmin", "scalar"}, {"Wmax", "scalar"}}); @@ -92,7 +92,7 @@ class STDPAdditiveEGPSpike : public WeightUpdateModels::Base SET_DERIVED_PARAMS({ {"tauPlusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauPlus")); }}, {"tauMinusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauMinus")); }}}); - SET_SYNAPSE_VARS({{"g", "scalar"}}); + SET_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); SET_EXTRA_GLOBAL_PARAMS({{"S", "scalar"}}); @@ -122,7 +122,7 @@ class STDPAdditiveEGPDynamics : public WeightUpdateModels::Base public: DECLARE_SNIPPET(STDPAdditiveEGPDynamics); SET_PARAM_NAMES({"Aplus", "Aminus", "Wmin", "Wmax"}); - SET_SYNAPSE_VARS({{"g", "scalar"}}); + SET_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); SET_EXTRA_GLOBAL_PARAMS({{"tauPlusDecay", "scalar"}, {"tauMinusDecay", "scalar"}}); @@ -152,7 +152,7 @@ class Continuous : public WeightUpdateModels::Base public: DECLARE_SNIPPET(Continuous); - SET_SYNAPSE_VARS({{"g", "scalar"}}); + SET_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE("addToPost(g * V_pre);\n"); }; @@ -196,7 +196,7 @@ class StaticPulseDynamics : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulseDynamics); - SET_SYNAPSE_VARS({ {"g", "scalar"} }); + SET_VARS({ {"g", "scalar"} }); SET_SIM_CODE("addToPost(g);\n"); SET_SYNAPSE_DYNAMICS_CODE("g *= 0.99;\n"); @@ -208,7 +208,7 @@ class StaticPulsePostLearn : public 
WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulsePostLearn); - SET_SYNAPSE_VARS({ {"g", "scalar"} }); + SET_VARS({ {"g", "scalar"} }); SET_SIM_CODE("addToPost(g);\n"); SET_LEARN_POST_CODE("g *= 0.99;\n"); @@ -284,7 +284,7 @@ class LIFAdditional : public NeuronModels::Base {"ExpTC", [](const ParamValues &pars, double dt) { return std::exp(-dt / pars.at("TauM")); }}, {"Rmembrane", [](const ParamValues &pars, double) { return pars.at("TauM") / pars.at("C"); }}}); - SET_NEURON_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); + SET_VARS({{"V", "scalar"}, {"RefracTime", "scalar"}}); }; IMPLEMENT_SNIPPET(LIFAdditional); } // Anonymous namespace diff --git a/tests/unit/weightUpdateModels.cc b/tests/unit/weightUpdateModels.cc index dfef43267e..ab76154296 100644 --- a/tests/unit/weightUpdateModels.cc +++ b/tests/unit/weightUpdateModels.cc @@ -17,7 +17,7 @@ class PiecewiseSTDPCopy : public WeightUpdateModels::Base public: SET_PARAM_NAMES({"tLrn", "tChng", "tDecay", "tPunish10", "tPunish01", "gMax", "gMid", "gSlope", "tauShift", "gSyn0"}); - SET_SYNAPSE_VARS({{"g", "scalar"}, {"gRaw", "scalar"}}); + SET_VARS({{"g", "scalar"}, {"gRaw", "scalar"}}); SET_SIM_CODE( "addToPost(g);\n" @@ -65,7 +65,7 @@ class STDPAdditive : public WeightUpdateModels::Base SET_DERIVED_PARAMS({ {"tauPlusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauPlus")); }}, {"tauMinusDecay", [](const ParamValues &pars, double dt){ return std::exp(-dt / pars.at("tauMinus")); }}}); - SET_SYNAPSE_VARS({{"g", "scalar"}}); + SET_VARS({{"g", "scalar"}}); SET_PRE_VARS({{"preTrace", "scalar"}}); SET_POST_VARS({{"postTrace", "scalar"}}); From 14dca3ac5469a5da29442b55a3433f5d01f8594c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 31 Oct 2023 11:32:39 +0000 Subject: [PATCH 59/60] update PygeNn to reflect simplified design --- include/genn/genn/varAccess.h | 6 +- pygenn/__init__.py | 5 +- pygenn/genn_groups.py | 14 +--- pygenn/genn_model.py | 24 +++---- pygenn/src/genn.cc | 72 +++++++------------ .../test_custom_connectivity_update.py | 6 +- tests/features/test_custom_update.py | 39 +++++----- tests/features/test_spike_propagation.py | 8 +-- tests/features/test_wu_vars.py | 10 +-- 9 files changed, 75 insertions(+), 109 deletions(-) diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h index b90dab8494..895ce62075 100644 --- a/include/genn/genn/varAccess.h +++ b/include/genn/genn/varAccess.h @@ -60,13 +60,13 @@ enum class CustomUpdateVarAccess : unsigned int READ_ONLY_SHARED = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::BATCH), // Variables which will be shared across neurons if per-element - READ_ONLY_SHARED_ELEMENT = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::ELEMENT), + READ_ONLY_SHARED_NEURON = static_cast(VarAccessMode::READ_ONLY) | static_cast(VarAccessDim::ELEMENT), // Reduction variables REDUCE_BATCH_SUM = static_cast(VarAccessMode::REDUCE_SUM) | static_cast(VarAccessDim::BATCH), REDUCE_BATCH_MAX = static_cast(VarAccessMode::REDUCE_MAX) | static_cast(VarAccessDim::BATCH), - REDUCE_ELEMENT_SUM = static_cast(VarAccessMode::REDUCE_SUM) | static_cast(VarAccessDim::ELEMENT), - REDUCE_ELEMENT_MAX = static_cast(VarAccessMode::REDUCE_MAX) | static_cast(VarAccessDim::ELEMENT), + REDUCE_NEURON_SUM = static_cast(VarAccessMode::REDUCE_SUM) | static_cast(VarAccessDim::ELEMENT), + REDUCE_NEURON_MAX = static_cast(VarAccessMode::REDUCE_MAX) | static_cast(VarAccessDim::ELEMENT), }; 
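// A minimal sketch (not part of this patch) of how the packed values above decompose
// via the helpers defined later in this header; the include path and use of the GeNN
// namespace are assumptions made purely for illustration:
#include <cassert>
#include "varAccess.h"

int main()
{
    using namespace GeNN;

    // READ_ONLY_DUPLICATE packs a READ_ONLY mode with the ELEMENT and BATCH dimensions
    const VarAccess v = VarAccess::READ_ONLY_DUPLICATE;
    assert(getVarAccessMode(v) == VarAccessMode::READ_ONLY);

    const VarAccessDim dims = getVarAccessDim(v);
    assert(dims & VarAccessDim::ELEMENT);
    assert(dims & VarAccessDim::BATCH);

    // Custom update axes are subtractive: clearing BATCH leaves an element-only dimension set
    assert(!(clearVarAccessDim(dims, VarAccessDim::BATCH) & VarAccessDim::BATCH));
    return 0;
}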
//---------------------------------------------------------------------------- diff --git a/pygenn/__init__.py b/pygenn/__init__.py index 9653790ef6..d2620d762a 100644 --- a/pygenn/__init__.py +++ b/pygenn/__init__.py @@ -5,9 +5,8 @@ from .genn import (create_var_ref, create_psm_var_ref, create_wu_pre_var_ref, create_wu_post_var_ref, create_wu_var_ref, create_egp_ref, create_psm_egp_ref, create_wu_egp_ref, - CustomUpdateVarAccess, NeuronVarAccess, PlogSeverity, - SpanType, SynapseMatrixType, SynapseVarAccess, - VarAccessMode, VarLocation) + CustomUpdateVarAccess, PlogSeverity, SpanType, + SynapseMatrixType, VarAccess, VarAccessMode, VarLocation) from .genn_model import (GeNNModel, create_neuron_model, create_postsynaptic_model, create_weight_update_model, diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index a69044e128..d245fa1bfa 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -26,7 +26,7 @@ def _get_num_var_copies(var_dims, batch_size): return () def _get_num_neuron_var_elements(var_dims, num_elements): - if (var_dims & VarAccessDim.NEURON): + if (var_dims & VarAccessDim.ELEMENT): return (num_elements,) else: return (1,) @@ -40,20 +40,12 @@ def _get_neuron_var_shape(var_dims, num_elements, batch_size, def _get_synapse_var_shape(var_dims, sg, batch_size): num_copies = _get_num_var_copies(var_dims, batch_size) - pre = (var_dims & VarAccessDim.PRE_NEURON) - post = (var_dims & VarAccessDim.POST_NEURON) - num_pre = sg.src.size - num_post = sg.trg.size - if pre and post: + if (var_dims & VarAccessDim.ELEMENT): if sg.matrix_type & SynapseMatrixWeight.KERNEL: return num_copies + (np.product(sg.kernel_size),) else: # **YUCK** this isn't correct - only backend knows correct stride - return num_copies + (num_pre * sg.max_connections,) - elif pre: - return num_copies + (num_pre,) - elif post: - return num_copies + (num_post,) + return num_copies + (sg.src.size * sg.max_connections,) else: return num_copies + (1,) diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index 1ccda3fff6..a97ff99eb7 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -65,10 +65,10 @@ DerivedParam, EGP, EGPRef, InitSparseConnectivitySnippetBase, InitToeplitzConnectivitySnippetBase, InitVarSnippetBase, - ModelSpecInternal, NeuronGroup, NeuronModelBase, NeuronVar, - ParamVal, PlogSeverity, PostsynapticModelBase, + ModelSpecInternal, NeuronGroup, NeuronModelBase, ParamVal, + PlogSeverity, PostsynapticModelBase, SparseConnectivityInit, SynapseGroup, SynapseMatrixType, - SynapseVar, ToeplitzConnectivityInit, UnresolvedType, + ToeplitzConnectivityInit, UnresolvedType, Var, VarInit, VarLocation, VarRef, WeightUpdateModelBase) from .shared_library_model import (SharedLibraryModelDouble, SharedLibraryModelFloat) @@ -959,7 +959,7 @@ def create_neuron_model(class_name, param_names=None, if var_name_types is not None: body["get_vars"] = \ - lambda self: [NeuronVar(*vn) for vn in var_name_types] + lambda self: [Var(*vn) for vn in var_name_types] if is_auto_refractory_required is not None: body["is_auto_refractory_required"] = \ @@ -1006,7 +1006,7 @@ def create_postsynaptic_model(class_name, param_names=None, if var_name_types is not None: body["get_vars"] = \ - lambda self: [NeuronVar(*vn) for vn in var_name_types] + lambda self: [Var(*vn) for vn in var_name_types] return create_model(class_name, PostsynapticModelBase, param_names, derived_params, extra_global_params, body) @@ -1101,15 +1101,15 @@ def create_weight_update_model(class_name, param_names=None, if var_name_types is not None: 
body["get_vars"] = \ - lambda self: [SynapseVar(*vn) for vn in var_name_types] + lambda self: [Var(*vn) for vn in var_name_types] if pre_var_name_types is not None: body["get_pre_vars"] = \ - lambda self: [NeuronVar(*vn) for vn in pre_var_name_types] + lambda self: [Var(*vn) for vn in pre_var_name_types] if post_var_name_types is not None: body["get_post_vars"] = \ - lambda self: [NeuronVar(*vn) for vn in post_var_name_types] + lambda self: [Var(*vn) for vn in post_var_name_types] return create_model(class_name, WeightUpdateModelBase, param_names, derived_params, extra_global_params, body) @@ -1149,7 +1149,7 @@ def create_current_source_model(class_name, param_names=None, if var_name_types is not None: body["get_vars"] = \ - lambda self: [NeuronVar(*vn) for vn in var_name_types] + lambda self: [Var(*vn) for vn in var_name_types] return create_model(class_name, CurrentSourceModelBase, param_names, derived_params, extra_global_params, body) @@ -1258,15 +1258,15 @@ def create_custom_connectivity_update_model(class_name, if var_name_types is not None: body["get_vars"] = \ - lambda self: [SynapseVar(*vn) for vn in var_name_types] + lambda self: [Var(*vn) for vn in var_name_types] if pre_var_name_types is not None: body["get_pre_vars"] = \ - lambda self: [NeuronVar(*vn) for vn in pre_var_name_types] + lambda self: [Var(*vn) for vn in pre_var_name_types] if post_var_name_types is not None: body["get_post_vars"] = \ - lambda self: [NeuronVar(*vn) for vn in post_var_name_types] + lambda self: [Var(*vn) for vn in post_var_name_types] if var_refs is not None: body["get_var_refs"] = lambda self: [VarRef(*v) for v in var_refs] diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc index dde31c37c5..849e5e05ef 100644 --- a/pygenn/src/genn.cc +++ b/pygenn/src/genn.cc @@ -100,7 +100,7 @@ class PyCurrentSourceModelBase : public PySnippet using Base = CurrentSourceModels::Base; public: virtual std::string getInjectionCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_injection_code", getInjectionCode); } - virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } + virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } }; //---------------------------------------------------------------------------- @@ -111,9 +111,9 @@ class PyCustomConnectivityUpdateModelBase : public PySnippet getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } - virtual std::vector getPreVars() const override { PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_pre_vars", getPreVars); } - virtual std::vector getPostVars() const override { PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_post_vars", getPostVars); } + virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } + virtual std::vector getPreVars() const override { PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_pre_vars", getPreVars); } + virtual std::vector getPostVars() const override { PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_post_vars", getPostVars); } virtual VarRefVec getVarRefs() const override { PYBIND11_OVERRIDE_NAME(VarRefVec, Base, "get_var_refs", getVarRefs); } virtual VarRefVec getPreVarRefs() const override { PYBIND11_OVERRIDE_NAME(VarRefVec, Base, "get_pre_var_refs", getPreVarRefs); } @@ -149,7 +149,7 @@ class PyNeuronModelBase : public PySnippet virtual std::string getThresholdConditionCode() const override { PYBIND11_OVERRIDE_NAME(std::string, 
Base, "get_threshold_condition_code", getThresholdConditionCode); } virtual std::string getResetCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_reset_code", getResetCode); } - virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } + virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } virtual Models::Base::ParamValVec getAdditionalInputVars() const override { PYBIND11_OVERRIDE_NAME(Models::Base::ParamValVec, Base, "get_additional_input_vars", getAdditionalInputVars); } virtual bool isAutoRefractoryRequired() const override { PYBIND11_OVERRIDE_NAME(bool, Base, "is_auto_refractory_required", isAutoRefractoryRequired); } @@ -163,7 +163,7 @@ class PyPostsynapticModelBase : public PySnippet { using Base = PostsynapticModels::Base; public: - virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } + virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } virtual std::string getDecayCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_decay_code", getDecayCode); } virtual std::string getApplyInputCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_apply_input_code", getApplyInputCode); } @@ -187,9 +187,9 @@ class PyWeightUpdateModelBase : public PySnippet virtual std::string getPreDynamicsCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_pre_dynamics_code", getPreDynamicsCode); } virtual std::string getPostDynamicsCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_post_dynamics_code", getPostDynamicsCode); } - virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } - virtual std::vector getPreVars() const override { PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_pre_vars", getPreVars); } - virtual std::vector getPostVars() const override { PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_post_vars", getPostVars); } + virtual std::vector getVars() const override{ PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_vars", getVars); } + virtual std::vector getPreVars() const override { PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_pre_vars", getPreVars); } + virtual std::vector getPostVars() const override { PYBIND11_OVERRIDE_NAME(std::vector, Base, "get_post_vars", getPostVars); } }; CodeGenerator::MemAlloc generateCode(ModelSpecInternal &model, CodeGenerator::BackendBase &backend, @@ -285,36 +285,25 @@ PYBIND11_MODULE(genn, m) //! Flags defining dimensions this variables has pybind11::enum_(m, "VarAccessDim") - .value("NEURON", VarAccessDim::NEURON) - .value("PRE_NEURON", VarAccessDim::PRE_NEURON) - .value("POST_NEURON", VarAccessDim::POST_NEURON) + .value("ELEMENT", VarAccessDim::ELEMENT) .value("BATCH", VarAccessDim::BATCH) .def("__and__", [](VarAccessDim a, VarAccessDim b){ return a & b; }, pybind11::is_operator()); //! Supported combinations of access mode and dimension for neuron variables - pybind11::enum_(m, "NeuronVarAccess") - .value("READ_WRITE", NeuronVarAccess::READ_WRITE) - .value("READ_ONLY", NeuronVarAccess::READ_ONLY) - .value("READ_ONLY_DUPLICATE", NeuronVarAccess::READ_ONLY_DUPLICATE) - .value("READ_ONLY_SHARED_NEURON", NeuronVarAccess::READ_ONLY_SHARED_NEURON); - - //! 
Supported combinations of access mode and dimension for synapse variables - pybind11::enum_(m, "SynapseVarAccess") - .value("READ_WRITE", SynapseVarAccess::READ_WRITE) - .value("READ_ONLY", SynapseVarAccess::READ_ONLY) - .value("READ_ONLY_DUPLICATE", SynapseVarAccess::READ_ONLY_DUPLICATE); - + pybind11::enum_(m, "VarAccess") + .value("READ_WRITE", VarAccess::READ_WRITE) + .value("READ_ONLY", VarAccess::READ_ONLY) + .value("READ_ONLY_DUPLICATE", VarAccess::READ_ONLY_DUPLICATE) + .value("READ_ONLY_SHARED_NEURON", VarAccess::READ_ONLY_SHARED_NEURON); //! Supported combinations of access mode and dimension for custom update variables /*! The axes are defined 'subtractively' ie VarAccessDim::BATCH indicates that this axis should be removed */ pybind11::enum_(m, "CustomUpdateVarAccess") .value("READ_WRITE", CustomUpdateVarAccess::READ_WRITE) .value("READ_ONLY", CustomUpdateVarAccess::READ_ONLY) - .value("READ_WRITE_SHARED", CustomUpdateVarAccess::READ_WRITE_SHARED) .value("READ_ONLY_SHARED", CustomUpdateVarAccess::READ_ONLY_SHARED) - .value("READ_WRITE_SHARED_NEURON", CustomUpdateVarAccess::READ_WRITE_SHARED_NEURON) .value("READ_ONLY_SHARED_NEURON", CustomUpdateVarAccess::READ_ONLY_SHARED_NEURON) .value("REDUCE_BATCH_SUM", CustomUpdateVarAccess::REDUCE_BATCH_SUM) .value("REDUCE_BATCH_MAX", CustomUpdateVarAccess::REDUCE_BATCH_MAX) @@ -360,8 +349,7 @@ PYBIND11_MODULE(genn, m) m.def("create_egp_ref", pybind11::overload_cast(&createEGPRef), pybind11::return_value_policy::move); m.def("create_psm_egp_ref", pybind11::overload_cast(&createPSMEGPRef), pybind11::return_value_policy::move); m.def("create_wu_egp_ref", pybind11::overload_cast(&createWUEGPRef), pybind11::return_value_policy::move); - m.def("get_var_access_dim", pybind11::overload_cast(&getVarAccessDim)); - m.def("get_var_access_dim", pybind11::overload_cast(&getVarAccessDim)); + m.def("get_var_access_dim", pybind11::overload_cast(&getVarAccessDim)); m.def("get_var_access_dim", pybind11::overload_cast(&getVarAccessDim)); //------------------------------------------------------------------------ @@ -693,29 +681,17 @@ PYBIND11_MODULE(genn, m) .def("get_code", &InitVarSnippet::Base::getCode); //------------------------------------------------------------------------ - // genn.NeuronVar + // genn.Var //------------------------------------------------------------------------ - pybind11::class_(m, "NeuronVar") - .def(pybind11::init()) + pybind11::class_(m, "Var") + .def(pybind11::init()) .def(pybind11::init()) - .def(pybind11::init()) + .def(pybind11::init()) .def(pybind11::init()) - .def_readonly("name", &Models::Base::NeuronVar::name) - .def_readonly("type", &Models::Base::NeuronVar::type) - .def_readonly("access", &Models::Base::NeuronVar::access); - - //------------------------------------------------------------------------ - // genn.SynapseVar - //------------------------------------------------------------------------ - pybind11::class_(m, "SynapseVar") - .def(pybind11::init()) - .def(pybind11::init()) - .def(pybind11::init()) - .def(pybind11::init()) - .def_readonly("name", &Models::Base::SynapseVar::name) - .def_readonly("type", &Models::Base::SynapseVar::type) - .def_readonly("access", &Models::Base::SynapseVar::access); - + .def_readonly("name", &Models::Base::Var::name) + .def_readonly("type", &Models::Base::Var::type) + .def_readonly("access", &Models::Base::Var::access); + //------------------------------------------------------------------------ // genn.CustomUpdateVar 
//------------------------------------------------------------------------ diff --git a/tests/features/test_custom_connectivity_update.py b/tests/features/test_custom_connectivity_update.py index 99d4a36717..25d422c851 100644 --- a/tests/features/test_custom_connectivity_update.py +++ b/tests/features/test_custom_connectivity_update.py @@ -3,7 +3,7 @@ from pygenn import types from pygenn import GeNNModel -from pygenn.genn import SynapseVarAccess, VarAccessMode +from pygenn.genn import VarAccessMode, VarAccess from bitarray import bitarray from bitarray.util import hex2ba @@ -26,8 +26,8 @@ weight_update_model = create_weight_update_model( "weight_update", - var_name_types=[("g", "scalar", SynapseVarAccess.READ_ONLY_DUPLICATE), - ("d", "unsigned int", SynapseVarAccess.READ_ONLY)]) + var_name_types=[("g", "scalar", VarAccess.READ_ONLY_DUPLICATE), + ("d", "unsigned int", VarAccess.READ_ONLY)]) # Snippet to initialise variable to hold its column-major index diff --git a/tests/features/test_custom_update.py b/tests/features/test_custom_update.py index 2e6ecba327..189eb23dc9 100644 --- a/tests/features/test_custom_update.py +++ b/tests/features/test_custom_update.py @@ -3,8 +3,7 @@ from pygenn import types from pygenn import GeNNModel -from pygenn.genn import (CustomUpdateVarAccess, NeuronVarAccess, - SynapseVarAccess, VarAccessMode) +from pygenn.genn import CustomUpdateVarAccess, VarAccess, VarAccessMode from scipy.special import softmax from pygenn import (create_current_source_model, @@ -28,26 +27,26 @@ def test_custom_update(backend, precision, batch_size): neuron_model = create_neuron_model( "neuron", - var_name_types=[("X", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE), - ("XShared", "scalar", NeuronVarAccess.READ_ONLY_SHARED_NEURON)]) + var_name_types=[("X", "scalar", VarAccess.READ_ONLY_DUPLICATE), + ("XShared", "scalar", VarAccess.READ_ONLY_SHARED_NEURON)]) current_source_model = create_current_source_model( "current_source", - var_name_types=[("X", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE), - ("XShared", "scalar", NeuronVarAccess.READ_ONLY_SHARED_NEURON)]) + var_name_types=[("X", "scalar", VarAccess.READ_ONLY_DUPLICATE), + ("XShared", "scalar", VarAccess.READ_ONLY_SHARED_NEURON)]) weight_update_model = create_weight_update_model( "weight_update", - var_name_types=[("X", "scalar", SynapseVarAccess.READ_ONLY_DUPLICATE)], - pre_var_name_types=[("preX", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE), - ("preXShared", "scalar", NeuronVarAccess.READ_ONLY_SHARED_NEURON)], - post_var_name_types=[("postX", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE), - ("postXShared", "scalar", NeuronVarAccess.READ_ONLY_SHARED_NEURON)]) + var_name_types=[("X", "scalar", VarAccess.READ_ONLY_DUPLICATE)], + pre_var_name_types=[("preX", "scalar", VarAccess.READ_ONLY_DUPLICATE), + ("preXShared", "scalar", VarAccess.READ_ONLY_SHARED_NEURON)], + post_var_name_types=[("postX", "scalar", VarAccess.READ_ONLY_DUPLICATE), + ("postXShared", "scalar", VarAccess.READ_ONLY_SHARED_NEURON)]) postsynaptic_update_model = create_postsynaptic_model( "postsynaptic_update", - var_name_types=[("psmX", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE), - ("psmXShared", "scalar", NeuronVarAccess.READ_ONLY_SHARED_NEURON)]) + var_name_types=[("psmX", "scalar", VarAccess.READ_ONLY_DUPLICATE), + ("psmXShared", "scalar", VarAccess.READ_ONLY_SHARED_NEURON)]) custom_update_model = create_custom_update_model( "custom_update", @@ -216,7 +215,7 @@ def test_custom_update(backend, precision, batch_size): def 
test_custom_update_transpose(backend, precision, batch_size): static_pulse_duplicate_model = create_weight_update_model( "static_pulse_duplicate", - var_name_types=[("g", "scalar", SynapseVarAccess.READ_ONLY_DUPLICATE)], + var_name_types=[("g", "scalar", VarAccess.READ_ONLY_DUPLICATE)], sim_code= """ addToPost(g); @@ -272,8 +271,8 @@ def test_custom_update_transpose(backend, precision, batch_size): def test_custom_update_neuron_reduce(backend, precision, batch_size): reduction_neuron_model = create_neuron_model( "reduction_neuron", - var_name_types=[("X", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE), - ("Y", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE)]) + var_name_types=[("X", "scalar", VarAccess.READ_ONLY_DUPLICATE), + ("Y", "scalar", VarAccess.READ_ONLY_DUPLICATE)]) softmax_1_custom_update_model = create_custom_update_model( "softmax_1", @@ -353,13 +352,13 @@ def test_custom_update_batch_reduction(backend, precision, batch_size): # **TODO** once VarAccess is refactored, we should really be able to reduce neuron shared across batch dimension neuron_model = create_neuron_model( "neuron", - var_name_types=[("X", "scalar", NeuronVarAccess.READ_ONLY_DUPLICATE), - ("SumX", "scalar", NeuronVarAccess.READ_ONLY)]) + var_name_types=[("X", "scalar", VarAccess.READ_ONLY_DUPLICATE), + ("SumX", "scalar", VarAccess.READ_ONLY)]) weight_update_model = create_weight_update_model( "weight_update", - var_name_types=[("X", "scalar", SynapseVarAccess.READ_ONLY_DUPLICATE), - ("SumX", "scalar", SynapseVarAccess.READ_ONLY)]) + var_name_types=[("X", "scalar", VarAccess.READ_ONLY_DUPLICATE), + ("SumX", "scalar", VarAccess.READ_ONLY)]) reduction_custom_update_model = create_custom_update_model( "reduction_custom_update", diff --git a/tests/features/test_spike_propagation.py b/tests/features/test_spike_propagation.py index 7e8698ae4a..5ab728e458 100644 --- a/tests/features/test_spike_propagation.py +++ b/tests/features/test_spike_propagation.py @@ -4,7 +4,7 @@ from pygenn import GeNNModel -from pygenn.genn import NeuronVarAccess, SpanType, SynapseVarAccess +from pygenn.genn import SpanType, VarAccess from pygenn import (create_neuron_model, create_sparse_connect_init_snippet, create_var_init_snippet, @@ -511,7 +511,7 @@ def test_reverse(backend, precision): pre_reverse_spike_source_model = create_neuron_model( "pre_reverse_spike_source", var_name_types=[("startSpike", "unsigned int"), - ("endSpike", "unsigned int", NeuronVarAccess.READ_ONLY_DUPLICATE), + ("endSpike", "unsigned int", VarAccess.READ_ONLY_DUPLICATE), ("x", "scalar")], extra_global_params=[("spikeTimes", "scalar*")], sim_code= @@ -533,7 +533,7 @@ def test_reverse(backend, precision): """ $(addToPre, $(g)); """, - var_name_types=[("g", "scalar", SynapseVarAccess.READ_ONLY)]) + var_name_types=[("g", "scalar", VarAccess.READ_ONLY)]) model = GeNNModel(precision, "test_reverse", backend=backend) model.dt = 1.0 @@ -617,7 +617,7 @@ def test_reverse_post(backend, precision): """ $(addToPre, $(g)); """, - var_name_types=[("g", "scalar", SynapseVarAccess.READ_ONLY)]) + var_name_types=[("g", "scalar", VarAccess.READ_ONLY)]) model = GeNNModel(precision, "test_reverse_post", backend=backend) model.dt = 1.0 diff --git a/tests/features/test_wu_vars.py b/tests/features/test_wu_vars.py index e972f99ce3..deeb6ea2c1 100644 --- a/tests/features/test_wu_vars.py +++ b/tests/features/test_wu_vars.py @@ -5,7 +5,7 @@ from pygenn import GeNNModel -from pygenn.genn import NeuronVarAccess +from pygenn.genn import VarAccess from pygenn import (create_neuron_model, 
create_weight_update_model, init_sparse_connectivity, init_var) @@ -220,7 +220,7 @@ def test_wu_var_cont(backend, precision, fuse, delay): pre_learn_post_weight_update_model = create_weight_update_model( "pre_learn_post_weight_update", var_name_types=[("w", "scalar")], - pre_var_name_types=[("s", "scalar"), ("shift", "scalar", NeuronVarAccess.READ_ONLY)], + pre_var_name_types=[("s", "scalar"), ("shift", "scalar", VarAccess.READ_ONLY)], learn_post_code= """ @@ -234,7 +234,7 @@ def test_wu_var_cont(backend, precision, fuse, delay): pre_sim_weight_update_model = create_weight_update_model( "pre_sim_weight_update", var_name_types=[("w", "scalar")], - pre_var_name_types=[("s", "scalar"), ("shift", "scalar", NeuronVarAccess.READ_ONLY)], + pre_var_name_types=[("s", "scalar"), ("shift", "scalar", VarAccess.READ_ONLY)], sim_code= """ @@ -251,7 +251,7 @@ def test_wu_var_cont(backend, precision, fuse, delay): post_learn_post_weight_update_model = create_weight_update_model( "post_learn_post_weight_update", var_name_types=[("w", "scalar")], - post_var_name_types=[("s", "scalar"), ("shift", "scalar", NeuronVarAccess.READ_ONLY)], + post_var_name_types=[("s", "scalar"), ("shift", "scalar", VarAccess.READ_ONLY)], learn_post_code= """ @@ -265,7 +265,7 @@ def test_wu_var_cont(backend, precision, fuse, delay): post_sim_weight_update_model = create_weight_update_model( "post_sim_weight_update", var_name_types=[("w", "scalar")], - post_var_name_types=[("s", "scalar"), ("shift", "scalar", NeuronVarAccess.READ_ONLY)], + post_var_name_types=[("s", "scalar"), ("shift", "scalar", VarAccess.READ_ONLY)], sim_code= """ From d3d12c9829b401a4880f1a9f99097b5c1c541ece Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 31 Oct 2023 11:36:17 +0000 Subject: [PATCH 60/60] moved one macro around --- include/genn/genn/customUpdateModels.h | 1 + include/genn/genn/models.h | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/include/genn/genn/customUpdateModels.h b/include/genn/genn/customUpdateModels.h index 8b26b46caa..12430d717d 100644 --- a/include/genn/genn/customUpdateModels.h +++ b/include/genn/genn/customUpdateModels.h @@ -7,6 +7,7 @@ //---------------------------------------------------------------------------- // Macros //---------------------------------------------------------------------------- +#define SET_CUSTOM_UPDATE_VARS(...) virtual std::vector getVars() const override{ return __VA_ARGS__; } #define SET_VAR_REFS(...) virtual VarRefVec getVarRefs() const override{ return __VA_ARGS__; } #define SET_EXTRA_GLOBAL_PARAM_REFS(...) virtual EGPRefVec getExtraGlobalParamRefs() const override{ return __VA_ARGS__; } #define SET_UPDATE_CODE(UPDATE_CODE) virtual std::string getUpdateCode() const override{ return UPDATE_CODE; } diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index e9c60a5b54..4943b224f0 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -34,7 +34,6 @@ class CustomConnectivityUpdateInternal; // Macros //---------------------------------------------------------------------------- #define SET_VARS(...) virtual std::vector getVars() const override{ return __VA_ARGS__; } -#define SET_CUSTOM_UPDATE_VARS(...) virtual std::vector getVars() const override{ return __VA_ARGS__; } #define DEFINE_REF_DETAIL_STRUCT(NAME, GROUP_TYPE, VAR_TYPE) using NAME = Detail //----------------------------------------------------------------------------
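
As a rough illustration of the simplified design these final patches converge on (this snippet is not part of the patch series; the model name and code string are placeholders), a weight update model defined from PyGeNN now declares its synaptic, presynaptic and postsynaptic variables with the same Var tuple format and the single VarAccess enum, exactly as the updated tests above do:

from pygenn import create_weight_update_model
from pygenn.genn import VarAccess

# All three variable lists use the unified (name, type[, access]) format;
# READ_ONLY_DUPLICATE marks a read-only variable duplicated across batches
static_pulse_like = create_weight_update_model(
    "static_pulse_like",
    var_name_types=[("g", "scalar", VarAccess.READ_ONLY_DUPLICATE)],
    pre_var_name_types=[("preTrace", "scalar")],
    post_var_name_types=[("postTrace", "scalar")],
    sim_code="addToPost(g);\n")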
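
The reworked shape helpers in genn_groups.py decide array shapes by testing which VarAccessDim flags an access value carries. A minimal sketch, not part of the patch series, of the same query done by hand, assuming the single-argument get_var_access_dim overload bound in genn.cc above (the prints just spell out what each flag means):

from pygenn.genn import VarAccess, VarAccessDim, get_var_access_dim

# Extract the dimension flags packed into an access value
dims = get_var_access_dim(VarAccess.READ_ONLY_DUPLICATE)

# Test individual flags with the bound __and__ operator, the same pattern
# _get_num_neuron_var_elements and _get_synapse_var_shape use
if dims & VarAccessDim.ELEMENT:
    print("variable has one value per neuron or synapse")
if dims & VarAccessDim.BATCH:
    print("variable is duplicated across batches")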