Skip to content

Commit

Permalink
Merge pull request #524 from genn-team/kernel_batch_reductions
Browse files Browse the repository at this point in the history
Custom updates on KERNEL variables
  • Loading branch information
neworderofjamie authored Jul 1, 2022
2 parents ea1c943 + 7468d32 commit 5e68ffe
Show file tree
Hide file tree
Showing 27 changed files with 922 additions and 774 deletions.
3 changes: 2 additions & 1 deletion include/genn/backends/single_threaded_cpu/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ class BACKEND_EXPORT Backend : public BackendBase
const Substitutions &kernelSubs, Handler handler) const override;
virtual void genSparseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const override;
virtual void genDenseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const override;
virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseKernelInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const final;
virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const final;
virtual void genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUpdateInitGroupMerged &cu, const Substitutions &kernelSubs, Handler handler) const final;

virtual void genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const override;
virtual void genVariablePull(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, size_t count) const override;
Expand Down
8 changes: 4 additions & 4 deletions include/genn/genn/code_generator/backendBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,11 @@ namespace CodeGenerator
class CustomUpdateTransposeWUGroupMerged;
class NeuronInitGroupMerged;
class CustomUpdateInitGroupMerged;
class CustomWUUpdateDenseInitGroupMerged;
class CustomWUUpdateInitGroupMerged;
class CustomWUUpdateSparseInitGroupMerged;
class SynapseConnectivityInitGroupMerged;
class SynapseDenseInitGroupMerged;
class SynapseInitGroupMerged;
class SynapseSparseInitGroupMerged;
class SynapseKernelInitGroupMerged;

}

Expand Down Expand Up @@ -282,7 +281,8 @@ class GENN_EXPORT BackendBase
const Substitutions &kernelSubs, Handler handler) const = 0;
virtual void genSparseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const = 0;
virtual void genDenseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const = 0;
virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseKernelInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const = 0;
virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const = 0;
virtual void genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUpdateInitGroupMerged &cu, const Substitutions &kernelSubs, Handler handler) const = 0;

//! Generate code for pushing a variable to the 'device'
virtual void genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const = 0;
Expand Down
148 changes: 142 additions & 6 deletions include/genn/genn/code_generator/backendSIMT.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ class GENN_EXPORT BackendSIMT : public BackendBase
genSynapseVariableRowInit(os, kernelSubs, handler);
}

virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseKernelInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const final;
virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const final;
virtual void genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUpdateInitGroupMerged &cu, const Substitutions &kernelSubs, Handler handler) const final;

//! Should 'scalar' variables be implemented on device or can host variables be used directly?
virtual bool isDeviceScalarRequired() const final { return true; }
Expand Down Expand Up @@ -165,6 +166,8 @@ class GENN_EXPORT BackendSIMT : public BackendBase
static size_t getNumPostsynapticUpdateThreads(const SynapseGroupInternal &sg);
static size_t getNumSynapseDynamicsThreads(const SynapseGroupInternal &sg);
static size_t getNumConnectivityInitThreads(const SynapseGroupInternal &sg);
static size_t getNumInitThreads(const SynapseGroupInternal &sg);
static size_t getNumInitThreads(const CustomUpdateWUInternal &cg);

//! Register a new presynaptic update strategy
/*! This function should be called with strategies in ascending order of preference */
Expand Down Expand Up @@ -321,6 +324,14 @@ class GENN_EXPORT BackendSIMT : public BackendBase
}


//! Convenience overload of genParallelGroup that applies no filtering:
//! every group in groups is processed by handler.
template<typename T, typename S>
void genParallelGroup(CodeStream &os, const Substitutions &kernelSubs, const std::vector<T> &groups, size_t &idStart,
                      S getPaddedSizeFunc, GroupHandler<T> handler) const
{
    // Predicate that accepts every merged group unconditionally
    const auto acceptAll = [](const T &) { return true; };

    // Delegate to the filtered overload
    genParallelGroup(os, kernelSubs, groups, idStart, getPaddedSizeFunc, acceptAll, handler);
}

template<typename G>
std::vector<ReductionTarget> genInitReductionTargets(CodeStream &os, const G &cg) const
{
Expand All @@ -345,13 +356,138 @@ class GENN_EXPORT BackendSIMT : public BackendBase
}
return reductionTargets;
}

// Helper function to generate kernel code to initialise variables associated with synapse
// group or custom WU update with dense/kernel connectivity.
// NOTE: the original span contained stray lines from the removed genParallelGroup
// declaration spliced into the body (a diff artifact), which made it malformed — removed.
template<typename G>
void genSynapseVarInit(CodeStream &os, const ModelSpecMerged &modelMerged, const G &g, Substitutions &popSubs,
                       bool initRNGRequired, bool kernel, size_t kernelDimensions) const
{
    os << "if(" << popSubs["id"] << " < ";

    // If synapse group has kernel weights, check ID against product of kernel dimensions
    if (kernel) {
        // Loop through kernel dimensions and multiply together
        os << "(";
        for (size_t i = 0; i < kernelDimensions; i++) {
            os << g.getKernelSize(i);
            if (i != (kernelDimensions - 1)) {
                os << " * ";
            }
        }
        os << ")";
    }
    // Otherwise, against number of postsynaptic neurons
    else {
        os << "group->numTrgNeurons";
    }
    os << ")";
    {
        CodeStream::Scope b(os);

        // If an RNG is required for initialisation,
        // make copy of global phillox RNG and skip ahead by thread id
        // **NOTE** not LOCAL id
        if(initRNGRequired) {
            genGlobalRNGSkipAhead(os, popSubs, "id");
        }

        // If synapse group has kernel weights
        if (kernel) {
            // Loop through kernel dimensions to generate separate indices
            for (size_t i = 0; i < kernelDimensions; i++) {
                os << "const unsigned int kernelID" << i << " = (" << popSubs["id"];

                // If this isn't the last dimension
                if (i < (kernelDimensions - 1)) {
                    // Loop backwards through other kernel and generate code to divide
                    // by product of subsequent dimensions
                    os << " / (";
                    for (size_t j = (kernelDimensions - 1); j > i; j--) {
                        os << g.getKernelSize(j);

                        if (j != (i + 1)) {
                            os << " * ";
                        }
                    }
                    os << ")";
                }
                os << ")";

                // If this isn't the first dimension, take modulus of kernel size
                if (i > 0) {
                    os << " % " << g.getKernelSize(i);
                }

                os << ";" << std::endl;

                // Add substitution so variable initialisation code can use id_kernel_i
                popSubs.addVarSubstitution("id_kernel_" + std::to_string(i), "kernelID" + std::to_string(i));
            }
        }
        // Otherwise, just substitute postsynaptic index
        else {
            popSubs.addVarSubstitution("id_post", popSubs["id"]);
        }

        // Generate init code
        g.generateInit(*this, os, modelMerged, popSubs);
    }
}

// Helper function to generate kernel code to initialise variables associated with synapse
// group or custom WU update with sparse connectivity.
// NOTE: the original span contained stray lines from the removed genParallelGroup body
// spliced after the opening brace (a diff artifact), which made it malformed — removed.
template<typename G>
void genSparseSynapseVarInit(CodeStream &os, const ModelSpecMerged &modelMerged, const G &g, Substitutions &popSubs,
                             bool varInitRequired, GroupHandler<G> handler) const
{
    // Calculate how many blocks rows need to be processed in
    // (in order to store row lengths in shared memory)
    const size_t blockSize = getKernelBlockSize(KernelInitializeSparse);
    os << "const unsigned int numBlocks = (group->numSrcNeurons + " << blockSize << " - 1) / " << blockSize << ";" << std::endl;

    os << "unsigned int idx = " << popSubs["id"] << ";" << std::endl;

    // Loop through blocks
    os << "for(unsigned int r = 0; r < numBlocks; r++)";
    {
        CodeStream::Scope b(os);

        // Calculate number of rows to process in this block (last block may be partial)
        os << "const unsigned numRowsInBlock = (r == (numBlocks - 1))";
        os << " ? ((group->numSrcNeurons - 1) % " << blockSize << ") + 1";
        os << " : " << blockSize << ";" << std::endl;

        // Use threads to copy block of sparse structure into shared memory
        genSharedMemBarrier(os);
        os << "if (" << getThreadID() << " < numRowsInBlock)";
        {
            CodeStream::Scope b(os);
            os << "shRowLength[" << getThreadID() << "] = group->rowLength[(r * " << blockSize << ") + " << getThreadID() << "];" << std::endl;
        }
        genSharedMemBarrier(os);

        // Loop through rows
        os << "for(unsigned int i = 0; i < numRowsInBlock; i++)";
        {
            CodeStream::Scope b(os);

            // If there is a synapse for this thread to initialise
            os << "if(" << popSubs["id"] << " < shRowLength[i])";
            {
                CodeStream::Scope b(os);

                // Generate initialisation code
                if(varInitRequired) {
                    popSubs.addVarSubstitution("id_pre", "((r * " + std::to_string(blockSize) + ") + i)");
                    popSubs.addVarSubstitution("id_post", "group->ind[idx]");
                    g.generateInit(*this, os, modelMerged, popSubs);
                }

                // Call handler
                handler(os, g, popSubs);
            }

            // If matrix is ragged, advance index to next row by adding stride
            os << "idx += group->rowStride;" << std::endl;
        }
    }
}

void genEmitSpike(CodeStream &os, const Substitutions &subs, const std::string &suffix, bool recordingEnabled) const;
Expand Down
49 changes: 49 additions & 0 deletions include/genn/genn/code_generator/codeGenUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,53 @@ void neuronSubstitutionsInSynapticCode(CodeGenerator::Substitutions &substitutio
// Substitute extra global parameters from neuron model
substitutions.addVarNameSubstitution(nm->getExtraGlobalParams(), sourceSuffix, "group->", destSuffix);
}

//! Is the kernel size of the merged groups heterogeneous in the given dimension?
//! \param group           merged group; must expose getArchetype() and getGroups()
//! \param dimensionIndex  kernel dimension to inspect
//! \param getKernelSizeFn functor mapping a group to its kernel-size vector
//! \return true when any merged group's size in this dimension differs from the archetype's
template<typename G, typename K>
bool isKernelSizeHeterogeneous(const G *group, size_t dimensionIndex, K getKernelSizeFn)
{
    // The archetype group's size in this dimension is the reference value
    const unsigned int referenceSize = getKernelSizeFn(group->getArchetype()).at(dimensionIndex);

    // Heterogeneous as soon as one merged group deviates from the reference
    for (const typename G::GroupInternal &g : group->getGroups()) {
        if (getKernelSizeFn(g).at(dimensionIndex) != referenceSize) {
            return true;
        }
    }
    return false;
}

//! Get a code expression for the size of kernel dimension dimensionIndex:
//! a reference to the merged group structure field when the size is
//! heterogeneous across merged groups, otherwise a compile-time literal.
template<typename G, typename K>
std::string getKernelSize(const G *group, size_t dimensionIndex, K getKernelSizeFn)
{
    const bool heterogeneous = isKernelSizeHeterogeneous(group, dimensionIndex, getKernelSizeFn);
    return heterogeneous
        ? ("group->kernelSize" + std::to_string(dimensionIndex))
        : std::to_string(getKernelSizeFn(group->getArchetype()).at(dimensionIndex));
}

//! Write to os an expression that indexes a flattened kernel array from the
//! per-dimension id_kernel_i substitutions in subs:
//!   sum over i of (id_kernel_i * product of sizes of all dimensions > i)
//! i.e. row-major flattening with dimension 0 outermost.
//! Each per-dimension size comes from getKernelSize(), so it is a literal when
//! homogeneous across merged groups and group->kernelSizeN otherwise.
template<typename G, typename K>
void genKernelIndex(const G *group, std::ostream &os, const CodeGenerator::Substitutions &subs,
K getKernelSizeFn)
{
// Loop through kernel dimensions to calculate array index
const auto &kernelSize = getKernelSizeFn(group->getArchetype());
for (size_t i = 0; i < kernelSize.size(); i++) {
os << "(" << subs["id_kernel_" + std::to_string(i)];
// Loop through remaining dimensions of kernel and multiply
for (size_t j = i + 1; j < kernelSize.size(); j++) {
os << " * " << getKernelSize(group, j, getKernelSizeFn);
}
os << ")";

// If this isn't the last dimension, add +
if (i != (kernelSize.size() - 1)) {
os << " + ";
}
}
}
} // namespace CodeGenerator
24 changes: 24 additions & 0 deletions include/genn/genn/code_generator/customUpdateGroupMerged.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

// GeNN code generator includes
#include "code_generator/codeGenUtils.h"
#include "code_generator/groupMerged.h"

//----------------------------------------------------------------------------
Expand Down Expand Up @@ -61,10 +62,33 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged<CustomUpdat
std::string getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const;
std::string getVarRefIndex(VarAccessDuplication varDuplication, const std::string &index) const;

//! Is kernel size heterogeneous in this dimension?
//! (true when any merged group's kernel size differs from the archetype's in this
//! dimension, in which case generated code references group->kernelSizeN rather
//! than a literal — delegates to the shared CodeGenerator helper)
bool isKernelSizeHeterogeneous(size_t dimensionIndex) const
{
return CodeGenerator::isKernelSizeHeterogeneous(this, dimensionIndex, getGroupKernelSize);
}

//! Get code expression for the size of kernel dimension dimensionIndex —
//! a literal when the size is homogeneous across merged groups, otherwise a
//! reference to the group->kernelSizeN structure field
std::string getKernelSize(size_t dimensionIndex) const
{
return CodeGenerator::getKernelSize(this, dimensionIndex, getGroupKernelSize);
}

//! Generate an index into a kernel based on the id_kernel_XXX variables in subs,
//! writing the resulting expression to os
void genKernelIndex(std::ostream& os, const CodeGenerator::Substitutions& subs) const
{
    // Delegate to the shared CodeGenerator helper, keyed on the underlying
    // synapse group's kernel size (dropped redundant 'return' of a void expression)
    CodeGenerator::genKernelIndex(this, os, subs, getGroupKernelSize);
}

protected:
CustomUpdateWUGroupMergedBase(size_t index, const std::string &precision, const std::string &, const BackendBase &backend,
const std::vector<std::reference_wrapper<const CustomUpdateWUInternal>> &groups);

private:
//! Selector passed to the CodeGenerator kernel-size helpers: a custom WU update's
//! kernel size is that of the synapse group it is attached to (see getSynapseGroup())
static const std::vector<unsigned int>& getGroupKernelSize(const CustomUpdateWUInternal& g)
{
return g.getSynapseGroup()->getKernelSize();
}
};

// ----------------------------------------------------------------------------
Expand Down
23 changes: 18 additions & 5 deletions include/genn/genn/code_generator/groupMerged.h
Original file line number Diff line number Diff line change
Expand Up @@ -1049,13 +1049,22 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged<SynapseGroupIntern
bool isTrgNeuronDerivedParamHeterogeneous(size_t paramIndex) const;

//! Is kernel size heterogeneous in this dimension?
bool isKernelSizeHeterogeneous(size_t dimensionIndex) const;
// Delegates to the shared CodeGenerator helper, which compares each merged
// group's kernel size in this dimension against the archetype's
bool isKernelSizeHeterogeneous(size_t dimensionIndex) const
{
return CodeGenerator::isKernelSizeHeterogeneous(this, dimensionIndex, getGroupKernelSize);
}

//! Get expression for kernel size in dimension (may be literal or group->kernelSizeXXX)
std::string getKernelSize(size_t dimensionIndex) const;
// Delegates to the shared CodeGenerator helper: returns a literal when the
// size is homogeneous across merged groups, otherwise group->kernelSizeN
std::string getKernelSize(size_t dimensionIndex) const
{
return CodeGenerator::getKernelSize(this, dimensionIndex, getGroupKernelSize);
}

//! Generate an index into a kernel based on the id_kernel_XXX variables in subs
void genKernelIndex(std::ostream &os, const CodeGenerator::Substitutions &subs) const;
//! Generate an index into a kernel based on the id_kernel_XXX variables in subs,
//! writing the resulting expression to os
void genKernelIndex(std::ostream& os, const CodeGenerator::Substitutions& subs) const
{
    // Delegate to the shared CodeGenerator helper
    // (dropped redundant 'return' of a void expression)
    CodeGenerator::genKernelIndex(this, os, subs, getGroupKernelSize);
}

std::string getPreSlot(unsigned int batchSize) const;
std::string getPostSlot(unsigned int batchSize) const;
Expand Down Expand Up @@ -1113,9 +1122,8 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged<SynapseGroupIntern
PresynapticUpdate,
PostsynapticUpdate,
SynapseDynamics,
DenseInit,
Init,
SparseInit,
KernelInit,
ConnectivityInit,
};

Expand Down Expand Up @@ -1178,6 +1186,11 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged<SynapseGroupIntern
//! Is postsynaptic neuron derived parameter referenced?
bool isTrgNeuronDerivedParamReferenced(size_t paramIndex) const;

//! Selector passed to the CodeGenerator kernel-size helpers: returns the
//! kernel size vector of an individual merged synapse group
static const std::vector<unsigned int>& getGroupKernelSize(const SynapseGroupInternal& g)
{
return g.getKernelSize();
}

//------------------------------------------------------------------------
// Members
//------------------------------------------------------------------------
Expand Down
Loading

0 comments on commit 5e68ffe

Please sign in to comment.