Merge pull request #478 from genn-team/procedural_kernelg
Proper implementation of kernel weights
neworderofjamie authored Nov 18, 2021
2 parents ed144cd + 0234b7e commit dcfd43f
Showing 38 changed files with 944 additions and 89 deletions.
1 change: 1 addition & 0 deletions include/genn/backends/single_threaded_cpu/backend.h
@@ -94,6 +94,7 @@ class BACKEND_EXPORT Backend : public BackendBase
const Substitutions &kernelSubs, Handler handler) const override;
virtual void genSparseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const override;
virtual void genDenseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const override;
virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseKernelInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const final;

virtual void genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const override;
virtual void genVariablePull(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, size_t count) const override;
2 changes: 2 additions & 0 deletions include/genn/genn/code_generator/backendBase.h
@@ -45,6 +45,7 @@ namespace CodeGenerator
class SynapseConnectivityInitGroupMerged;
class SynapseDenseInitGroupMerged;
class SynapseSparseInitGroupMerged;
class SynapseKernelInitGroupMerged;

}

@@ -281,6 +282,7 @@ class GENN_EXPORT BackendBase
const Substitutions &kernelSubs, Handler handler) const = 0;
virtual void genSparseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const = 0;
virtual void genDenseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const = 0;
virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseKernelInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const = 0;

//! Generate code for pushing a variable to the 'device'
virtual void genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const = 0;
3 changes: 2 additions & 1 deletion include/genn/genn/code_generator/backendSIMT.h
@@ -133,7 +133,8 @@ class GENN_EXPORT BackendSIMT : public BackendBase
{
genSynapseVariableRowInit(os, kernelSubs, handler);
}


virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseKernelInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const final;

//! Should 'scalar' variables be implemented on device or can host variables be used directly?
virtual bool isDeviceScalarRequired() const final { return true; }
27 changes: 0 additions & 27 deletions include/genn/genn/code_generator/codeGenUtils.h
@@ -147,31 +147,4 @@ void neuronSubstitutionsInSynapticCode(CodeGenerator::Substitutions &substitutio
// Substitute extra global parameters from neuron model
substitutions.addVarNameSubstitution(nm->getExtraGlobalParams(), sourceSuffix, "group->", destSuffix);
}

template<typename G>
void genKernelIndex(std::ostream &os, const CodeGenerator::Substitutions &subs, const G &sg)
{
// Loop through kernel dimensions to calculate array index
const auto &kernelSize = sg.getArchetype().getKernelSize();
for(size_t i = 0; i < kernelSize.size(); i++) {
os << "(" << subs["id_kernel_" + std::to_string(i)];
// Loop through remaining dimensions of kernel
for(size_t j = i + 1; j < kernelSize.size(); j++) {
// If kernel size is heterogeneous in this dimension, multiply by value from group structure
if(sg.isKernelSizeHeterogeneous(j)) {
os << " * group->kernelSize" << j;
}
// Otherwise, multiply by literal
else {
os << " * " << kernelSize.at(j);
}
}
os << ")";

// If this isn't the last dimension, add +
if(i != (kernelSize.size() - 1)) {
os << " + ";
}
}
}
} // namespace CodeGenerator
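
This helper moves from the free-function form above into SynapseGroupMergedBase (see groupMerged.h below). For reference, the code it emits flattens an N-dimensional kernel coordinate into a row-major array index; a minimal standalone sketch of that arithmetic, with illustrative names that are not part of the GeNN API:

#include <cstddef>
#include <vector>

// Row-major flattening equivalent to the expression genKernelIndex() emits
size_t flattenKernelIndex(const std::vector<size_t> &kernelSize,
                          const std::vector<size_t> &id)
{
    size_t index = 0;
    for(size_t i = 0; i < kernelSize.size(); i++) {
        // Stride of dimension i = product of all later dimensions
        size_t stride = 1;
        for(size_t j = i + 1; j < kernelSize.size(); j++) {
            stride *= kernelSize[j];
        }
        index += id[i] * stride;
    }
    return index;
}

// e.g. for a {3, 3, 1, 32} kernel (height, width, input channels, output
// channels), coordinate {1, 2, 0, 5} maps to 1*96 + 2*32 + 0*32 + 5 = 165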
8 changes: 8 additions & 0 deletions include/genn/genn/code_generator/groupMerged.h
@@ -1044,6 +1044,12 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged<SynapseGroupIntern

//! Is kernel size heterogeneous in this dimension?
bool isKernelSizeHeterogeneous(size_t dimensionIndex) const;

//! Get expression for kernel size in dimension (may be literal or group->kernelSizeXXX)
std::string getKernelSize(size_t dimensionIndex) const;

//! Generate an index into a kernel based on the id_kernel_XXX variables in subs
void genKernelIndex(std::ostream &os, const CodeGenerator::Substitutions &subs) const;

std::string getPreSlot(unsigned int batchSize) const;
std::string getPostSlot(unsigned int batchSize) const;
@@ -1090,6 +1096,7 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged<SynapseGroupIntern
}

static std::string getSynVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index);
static std::string getKernelVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index);

protected:
//----------------------------------------------------------------------------
@@ -1102,6 +1109,7 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged<SynapseGroupIntern
SynapseDynamics,
DenseInit,
SparseInit,
KernelInit,
ConnectivityInit,
};

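
A hedged sketch of how the new getKernelSize() member presumably behaves, based on the logic of the removed free function in codeGenUtils.h above (the real definition is in groupMerged.cc, which this diff view does not show):

std::string SynapseGroupMergedBase::getKernelSize(size_t dimensionIndex) const
{
    // Heterogeneous dimensions must be read from the merged group struct;
    // homogeneous ones can be baked in as literals
    if(isKernelSizeHeterogeneous(dimensionIndex)) {
        return "group->kernelSize" + std::to_string(dimensionIndex);
    }
    else {
        return std::to_string(getArchetype().getKernelSize().at(dimensionIndex));
    }
}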
34 changes: 33 additions & 1 deletion include/genn/genn/code_generator/initGroupMerged.h
@@ -158,6 +158,38 @@ class GENN_EXPORT SynapseSparseInitGroupMerged : public SynapseGroupMergedBase
static const std::string name;
};

//----------------------------------------------------------------------------
// CodeGenerator::SynapseKernelInitGroupMerged
//----------------------------------------------------------------------------
class GENN_EXPORT SynapseKernelInitGroupMerged : public SynapseGroupMergedBase
{
public:
SynapseKernelInitGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend,
const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
: SynapseGroupMergedBase(index, precision, timePrecision, backend, SynapseGroupMergedBase::Role::KernelInit, "", groups)
{}

boost::uuids::detail::sha1::digest_type getHashDigest() const
{
return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::KernelInit);
}

void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal,
CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar,
CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const
{
generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar,
runnerVarDecl, runnerMergedStructAlloc, name);
}

void generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const;

//----------------------------------------------------------------------------
// Static constants
//----------------------------------------------------------------------------
static const std::string name;
};


// ----------------------------------------------------------------------------
// CodeGenerator::SynapseConnectivityInitGroupMerged
@@ -384,4 +416,4 @@ class GENN_EXPORT CustomWUUpdateSparseInitGroupMerged : public CustomUpdateInitG
static const std::string name;
};

} // namespace CodeGenerator
} // namespace CodeGenerator
7 changes: 7 additions & 0 deletions include/genn/genn/code_generator/modelSpecMerged.h
@@ -129,6 +129,9 @@ class GENN_EXPORT ModelSpecMerged

//! Get merged synapse groups with sparse connectivity which require initialisation
const std::vector<SynapseSparseInitGroupMerged> &getMergedSynapseSparseInitGroups() const{ return m_MergedSynapseSparseInitGroups; }

//! Get merged synapse groups with kernel connectivity which require initialisation
const std::vector<SynapseKernelInitGroupMerged> &getMergedSynapseKernelInitGroups() const{ return m_MergedSynapseKernelInitGroups; }

//! Get merged custom update groups with sparse connectivity which require initialisation
const std::vector<CustomWUUpdateSparseInitGroupMerged> &getMergedCustomWUUpdateSparseInitGroups() const { return m_MergedCustomWUUpdateSparseInitGroups; }
@@ -170,6 +173,7 @@ class GENN_EXPORT ModelSpecMerged
void genMergedSynapseDenseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseDenseInitGroups); }
void genMergedSynapseConnectivityInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseConnectivityInitGroups); }
void genMergedSynapseSparseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseSparseInitGroups); }
void genMergedSynapseKernelInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseKernelInitGroups); }
void genMergedCustomWUUpdateSparseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomWUUpdateSparseInitGroups); }
void genMergedNeuronSpikeQueueUpdateStructs(CodeStream &os, const BackendBase &backend) const{ genMergedStructures(os, backend, m_MergedNeuronSpikeQueueUpdateGroups); }
void genMergedNeuronPrevSpikeTimeUpdateStructs(CodeStream &os, const BackendBase &backend) const{ genMergedStructures(os, backend, m_MergedNeuronPrevSpikeTimeUpdateGroups); }
@@ -396,6 +400,9 @@ class GENN_EXPORT ModelSpecMerged

//! Merged synapse groups with sparse connectivity which require initialisation
std::vector<SynapseSparseInitGroupMerged> m_MergedSynapseSparseInitGroups;

//! Merged synapse groups with kernel connectivity which require initialisation
std::vector<SynapseKernelInitGroupMerged> m_MergedSynapseKernelInitGroups;

//! Merged custom update groups with sparse connectivity which require initialisation
std::vector<CustomWUUpdateSparseInitGroupMerged> m_MergedCustomWUUpdateSparseInitGroups;
3 changes: 2 additions & 1 deletion include/genn/genn/synapseGroup.h
@@ -127,7 +127,8 @@ class GENN_EXPORT SynapseGroup
unsigned int getMaxDendriticDelayTimesteps() const{ return m_MaxDendriticDelayTimesteps; }
SynapseMatrixType getMatrixType() const{ return m_MatrixType; }
const std::vector<unsigned int> &getKernelSize() const { return m_KernelSize; }

size_t getKernelSizeFlattened() const;

//! Get variable mode used for variables used to combine input from this synapse group
VarLocation getInSynLocation() const { return m_InSynLocation; }

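
getKernelSizeFlattened() is declared here but defined in synapseGroup.cc, which this view does not show; a minimal sketch of the obvious definition, the product of all kernel dimensions:

#include <functional>
#include <numeric>

size_t SynapseGroup::getKernelSizeFlattened() const
{
    // Total number of kernel weights, e.g. a {3, 3, 1, 32} kernel -> 288
    return std::accumulate(getKernelSize().cbegin(), getKernelSize().cend(),
                           size_t{1}, std::multiplies<size_t>());
}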
2 changes: 2 additions & 0 deletions include/genn/genn/synapseMatrixType.h
@@ -19,6 +19,7 @@ enum class SynapseMatrixWeight : unsigned int
INDIVIDUAL = (1 << 6),
PROCEDURAL = (1 << 7),
INDIVIDUAL_PSM = (1 << 8),
KERNEL = (1 << 9)
};

//! Supported combinations of SynapseMatrixConnectivity and SynapseMatrixWeight
@@ -36,6 +37,7 @@ enum class SynapseMatrixType : unsigned int
PROCEDURAL_GLOBALG = static_cast<unsigned int>(SynapseMatrixConnectivity::PROCEDURAL) | static_cast<unsigned int>(SynapseMatrixWeight::GLOBAL),
PROCEDURAL_GLOBALG_INDIVIDUAL_PSM = static_cast<unsigned int>(SynapseMatrixConnectivity::PROCEDURAL) | static_cast<unsigned int>(SynapseMatrixWeight::GLOBAL) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL_PSM),
PROCEDURAL_PROCEDURALG = static_cast<unsigned int>(SynapseMatrixConnectivity::PROCEDURAL) | static_cast<unsigned int>(SynapseMatrixWeight::PROCEDURAL) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL_PSM),
PROCEDURAL_KERNELG = static_cast<unsigned int>(SynapseMatrixConnectivity::PROCEDURAL) | static_cast<unsigned int>(SynapseMatrixWeight::KERNEL) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL_PSM),
};

//----------------------------------------------------------------------------
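
Because SynapseMatrixType values are bitwise compositions of connectivity and weight flags, the new KERNEL bit can be tested with a mask; a sketch, where the helper name is illustrative rather than GeNN API:

inline bool isKernelWeight(SynapseMatrixType type)
{
    return (static_cast<unsigned int>(type)
            & static_cast<unsigned int>(SynapseMatrixWeight::KERNEL)) != 0;
}

// isKernelWeight(SynapseMatrixType::PROCEDURAL_KERNELG)    == true
// isKernelWeight(SynapseMatrixType::PROCEDURAL_PROCEDURALG) == false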
13 changes: 11 additions & 2 deletions pygenn/genn_groups.py
@@ -19,6 +19,7 @@
from .genn_wrapper import (SynapseMatrixConnectivity_SPARSE,
SynapseMatrixConnectivity_BITMASK,
SynapseMatrixConnectivity_DENSE,
SynapseMatrixWeight_KERNEL,
SynapseMatrixWeight_INDIVIDUAL,
SynapseMatrixWeight_INDIVIDUAL_PSM,
VarLocation_HOST,
@@ -740,6 +741,8 @@ def weight_update_var_size(self):
return self.trg.size * self.src.size
elif self.is_ragged:
return self.max_row_length * self.src.size
elif self.has_kernel_synapse_vars:
return int(np.product(self.pop.get_kernel_size()))

@property
def max_row_length(self):
@@ -902,6 +905,12 @@ def has_individual_synapse_vars(self):
return (self.weight_sharing_master is None
and (self.matrix_type & SynapseMatrixWeight_INDIVIDUAL) != 0)

@property
def has_kernel_synapse_vars(self):
"""Tests whether synaptic connectivity has kernel weights"""
return (self.weight_sharing_master is None
and (self.matrix_type & SynapseMatrixWeight_KERNEL) != 0)

@property
def has_individual_postsynaptic_vars(self):
"""Tests whether synaptic connectivity has
@@ -1191,7 +1200,7 @@ def load(self):
var_data = self.vars[v.name]

# If population has individual synapse variables
if self.has_individual_synapse_vars:
if self.has_individual_synapse_vars or self.has_kernel_synapse_vars:
# If variable is located on host
var_loc = self.pop.get_wuvar_location(v.name)
if (var_loc & VarLocation_HOST) != 0:
@@ -1303,7 +1312,7 @@ def _init_wum_var(self, var_data, num_copies):
# If connectivity is dense,
# copy variables directly into view
# **NOTE** we assume order is row-major
if self.is_dense:
if self.is_dense or self.has_kernel_synapse_vars:
var_data.view[:] = var_data.values
elif self.is_ragged:
# Sort variable to match GeNN order
3 changes: 3 additions & 0 deletions src/genn/backends/cuda/backend.cc
@@ -859,6 +859,7 @@ void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged,
modelMerged.genMergedCustomUpdateInitGroupStructs(os, *this);
modelMerged.genMergedCustomWUUpdateDenseInitGroupStructs(os, *this);
modelMerged.genMergedSynapseDenseInitGroupStructs(os, *this);
modelMerged.genMergedSynapseKernelInitGroupStructs(os, *this);
modelMerged.genMergedSynapseConnectivityInitGroupStructs(os, *this);
modelMerged.genMergedSynapseSparseInitGroupStructs(os, *this);
modelMerged.genMergedCustomWUUpdateSparseInitGroupStructs(os, *this);
@@ -868,6 +869,7 @@
genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateInitGroups());
genMergedStructArrayPush(os, modelMerged.getMergedCustomWUUpdateDenseInitGroups());
genMergedStructArrayPush(os, modelMerged.getMergedSynapseDenseInitGroups());
genMergedStructArrayPush(os, modelMerged.getMergedSynapseKernelInitGroups());
genMergedStructArrayPush(os, modelMerged.getMergedSynapseConnectivityInitGroups());
genMergedStructArrayPush(os, modelMerged.getMergedSynapseSparseInitGroups());
genMergedStructArrayPush(os, modelMerged.getMergedCustomWUUpdateSparseInitGroups());
@@ -885,6 +887,7 @@
modelMerged.getMergedCustomUpdateInitGroups(), [this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelInitialize); },
modelMerged.getMergedCustomWUUpdateDenseInitGroups(), [this](const CustomUpdateWUInternal &cg){ return padKernelSize(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons(), KernelInitialize); },
modelMerged.getMergedSynapseDenseInitGroups(), [this](const SynapseGroupInternal &sg){ return padKernelSize(sg.getTrgNeuronGroup()->getNumNeurons(), KernelInitialize); },
modelMerged.getMergedSynapseKernelInitGroups(), [this](const SynapseGroupInternal &sg){ return padKernelSize(sg.getKernelSizeFlattened(), KernelInitialize); },
modelMerged.getMergedSynapseConnectivityInitGroups(), [this](const SynapseGroupInternal &sg){ return padKernelSize(getNumConnectivityInitThreads(sg), KernelInitialize); });

// Generate data structure for accessing merged groups from within sparse initialisation kernel
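
The new entry gives kernel-weight initialisation one GPU thread per flattened kernel element, with the count padded to a whole number of blocks. A sketch of the padding arithmetic this implies (padKernelSize's actual signature takes a Kernel enum identifying the block size; this is an assumption-laden simplification):

size_t padToBlockSize(size_t numThreads, size_t blockSize)
{
    // Round up to a whole number of CUDA blocks
    return ((numThreads + blockSize - 1) / blockSize) * blockSize;
}

// e.g. a 288-element kernel with 32-thread blocks fills 9 blocks exactly
// (288 threads); 290 elements would pad to 320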