Merge pull request #511 from genn-team/new_sparse_synapse_dynamics
Replace synRemap mechanism with much simpler one
neworderofjamie authored Apr 27, 2022
2 parents 22da550 + f3c1a0f commit d941d60
Showing 8 changed files with 34 additions and 104 deletions.
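In outline, the change swaps one thread-to-synapse mapping for another. Previously, sparse synapse dynamics and weight-update custom updates were parallelised over a compacted list of existing synapses held in a per-group synRemap array, which the sparse-initialisation kernel had to rebuild every time connectivity was initialised. After this commit those kernels are simply parallelised over the padded row-major connectivity matrix, and threads that fall beyond the end of a ragged row exit early. A rough sketch of the two schemes, using the merged-group field names that appear in the diff below:

    // Old: one thread per existing synapse, via an extra indirection
    // synRemap[0]      = total number of synapses
    // synRemap[1 + id] = row-major index of the id-th synapse
    const unsigned int s   = group->synRemap[1 + id];
    const unsigned int pre = s / group->rowStride;

    // New: one thread per padded matrix entry, no auxiliary structure
    const unsigned int row = id / group->rowStride;
    const unsigned int col = id % group->rowStride;
    if (col < group->rowLength[row]) { /* synapse exists, process it */ }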
1 change: 0 additions & 1 deletion include/genn/backends/single_threaded_cpu/backend.h
@@ -141,7 +141,6 @@ class BACKEND_EXPORT Backend : public BackendBase
//! Different backends seed RNGs in different ways. Does this one initialise population RNGS on device?
virtual bool isPopulationRNGInitialisedOnDevice() const override { return false; }

virtual bool isSynRemapRequired(const SynapseGroupInternal&) const override{ return false; }
virtual bool isPostsynapticRemapRequired() const override{ return true; }

//! Backends which support batch-parallelism might require an additional host reduction phase after reduction kernels
3 changes: 0 additions & 3 deletions include/genn/genn/code_generator/backendBase.h
@@ -382,9 +382,6 @@ class GENN_EXPORT BackendBase
//! Different backends seed RNGs in different ways. Does this one initialise population RNGS on device?
virtual bool isPopulationRNGInitialisedOnDevice() const = 0;

//! Different backends may implement synapse dynamics differently. Does this one require a synapse remapping data structure for synapse group?
virtual bool isSynRemapRequired(const SynapseGroupInternal &sg) const = 0;

//! Different backends may implement synaptic plasticity differently. Does this one require a postsynaptic remapping data structure?
virtual bool isPostsynapticRemapRequired() const = 0;

1 change: 0 additions & 1 deletion include/genn/genn/code_generator/backendSIMT.h
@@ -143,7 +143,6 @@ class GENN_EXPORT BackendSIMT : public BackendBase
virtual bool isGlobalDeviceRNGRequired(const ModelSpecMerged &modelMerged) const final;
virtual bool isPopulationRNGRequired() const final { return true; }

virtual bool isSynRemapRequired(const SynapseGroupInternal &sg) const final;
virtual bool isPostsynapticRemapRequired() const final { return true; }

//------------------------------------------------------------------------
96 changes: 29 additions & 67 deletions src/genn/genn/code_generator/backendSIMT.cc
@@ -128,13 +128,6 @@ bool BackendSIMT::isGlobalDeviceRNGRequired(const ModelSpecMerged &modelMerged)
return false;
}
//--------------------------------------------------------------------------
bool BackendSIMT::isSynRemapRequired(const SynapseGroupInternal &sg) const
{
// This synapse group requires synRemap if it's sparse and either has synapse dynamics or is targeted by any custom update
return ((sg.getMatrixType() & SynapseMatrixConnectivity::SPARSE) &&
(!sg.getWUModel()->getSynapseDynamicsCode().empty() || sg.areWUVarReferencedByCustomUpdate()));
}
//--------------------------------------------------------------------------
size_t BackendSIMT::getNumInitialisationRNGStreams(const ModelSpecMerged &modelMerged) const
{
// Calculate total number of threads used for neuron initialisation group
@@ -807,7 +800,7 @@ void BackendSIMT::genSynapseDynamicsKernel(CodeStream &os, const Substitutions &
Substitutions synSubs(&popSubs);

if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
os << "if (" << popSubs["id"] << " < group->synRemap[0])";
os << "if (" << popSubs["id"] << " < (group->numSrcNeurons * group->rowStride))";
}
else {
os << "if (" << popSubs["id"] << " < (group->numSrcNeurons * group->numTrgNeurons))";
@@ -816,12 +809,16 @@ void BackendSIMT::genSynapseDynamicsKernel(CodeStream &os, const Substitutions &
CodeStream::Scope b(os);

if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
// Determine synapse and presynaptic indices for this thread
os << "const unsigned int s = group->synRemap[1 + " << popSubs["id"] << "];" << std::endl;
// **OPTIMIZE** we can do a fast constant divide optimization here and use the result to calculate the remainder
os << "const unsigned int row = " << popSubs["id"] << " / group->rowStride;" << std::endl;
os << "const unsigned int col = " << popSubs["id"] << " % group->rowStride;" << std::endl;

synSubs.addVarSubstitution("id_pre", "row");
synSubs.addVarSubstitution("id_post", "group->ind[" + popSubs["id"] + "]");
synSubs.addVarSubstitution("id_syn", popSubs["id"]);

synSubs.addVarSubstitution("id_pre", "(s / group->rowStride)");
synSubs.addVarSubstitution("id_post", "group->ind[s]");
synSubs.addVarSubstitution("id_syn", "s");
os << "if(col < group->rowLength[row])";
os << CodeStream::OB(1);
}
else {
// **OPTIMIZE** we can do a fast constant divide optimization here and use the result to calculate the remainder
@@ -846,6 +843,10 @@ void BackendSIMT::genSynapseDynamicsKernel(CodeStream &os, const Substitutions &
}

sg.generateSynapseUpdate(*this, os, modelMerged, synSubs);

if (sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
os << CodeStream::CB(1);
}
}
});
}
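Put together, the substitutions above mean the SIMT backends now emit code along these lines for a sparse synapse-dynamics group (a sketch of the generated body rather than verbatim output; id stands for the per-thread id substitution):

    if (id < (group->numSrcNeurons * group->rowStride)) {
        // Recover row and column from the padded row-major index
        const unsigned int row = id / group->rowStride;
        const unsigned int col = id % group->rowStride;
        if (col < group->rowLength[row]) {
            // id_pre  -> row
            // id_post -> group->ind[id]
            // id_syn  -> id
            // ... user synapse dynamics code is inserted here ...
        }
    }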
@@ -975,7 +976,7 @@ void BackendSIMT::genCustomUpdateWUKernel(CodeStream &os, const Substitutions &k
}

if(cg.getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
os << "if (" << cuSubs["id"] << " < group->synRemap[0])";
os << "if (" << cuSubs["id"] << " < (group->numSrcNeurons * group->rowStride))";
}
else {
os << "if (" << cuSubs["id"] << " < size)";
@@ -984,12 +985,16 @@ void BackendSIMT::genCustomUpdateWUKernel(CodeStream &os, const Substitutions &k
CodeStream::Scope b(os);

if(cg.getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
// Determine synapse and presynaptic indices for this thread
os << "const unsigned int s = group->synRemap[1 + " << cuSubs["id"] << "];" << std::endl;
// **OPTIMIZE** we can do a fast constant divide optimization here and use the result to calculate the remainder
os << "const unsigned int row = " << cuSubs["id"] << " / group->rowStride;" << std::endl;
os << "const unsigned int col = " << cuSubs["id"] << " % group->rowStride;" << std::endl;

cuSubs.addVarSubstitution("id_pre", "(s / group->rowStride)");
cuSubs.addVarSubstitution("id_post", "group->ind[s]");
cuSubs.addVarSubstitution("id_syn", "s");
cuSubs.addVarSubstitution("id_pre", "row");
cuSubs.addVarSubstitution("id_post", "group->ind[" + cuSubs["id"] + "]");
cuSubs.addVarSubstitution("id_syn", cuSubs["id"]);

os << "if(col < group->rowLength[row])";
os << CodeStream::OB(2);
}
else {
// **OPTIMIZE** we can do a fast constant divide optimization here and use the result to calculate the remainder
@@ -1033,6 +1038,10 @@ void BackendSIMT::genCustomUpdateWUKernel(CodeStream &os, const Substitutions &k
os << "group->" << r.name << "[" << cuSubs["id_syn"] << "] = lr" << r.name << ";" << std::endl;
}
}

if (cg.getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
os << CodeStream::CB(2);
}
}
});
}
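For contrast, the removed path launched only as many threads as there were synapses and indirected through synRemap; reconstructed from the deleted lines above, the previously generated body for a sparse weight-update custom update looked roughly like this:

    if (id < group->synRemap[0]) {                       // synRemap[0] held the synapse count
        const unsigned int s = group->synRemap[1 + id];  // row-major index of this synapse
        // id_pre  -> s / group->rowStride
        // id_post -> group->ind[s]
        // id_syn  -> s
        // ... user update code ...
    }

The new scheme pads the launch to numSrcNeurons * rowStride threads and lets the rowLength guard discard the padding; some threads are wasted on short rows, but the indirection and the whole machinery for building synRemap disappear.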
@@ -1495,12 +1504,7 @@ void BackendSIMT::genInitializeSparseKernel(CodeStream &os, const Substitutions
// Shared memory array so row lengths don't have to be read by EVERY postsynaptic thread
// **TODO** check actually required
os << getSharedPrefix() << "unsigned int shRowLength[" << getKernelBlockSize(KernelInitializeSparse) << "];" << std::endl;
if(std::any_of(modelMerged.getModel().getSynapseGroups().cbegin(), modelMerged.getModel().getSynapseGroups().cend(),
[this](const ModelSpec::SynapseGroupValueType &s) { return isSynRemapRequired(s.second); }))
{
os << getSharedPrefix() << "unsigned int shRowStart[" << getKernelBlockSize(KernelInitializeSparse) + 1 << "];" << std::endl;
}


// Initialise weight update variables for synapse groups with sparse connectivity
genParallelGroup<SynapseSparseInitGroupMerged>(os, kernelSubs, modelMerged.getMergedSynapseSparseInitGroups(), idStart,
[this](const SynapseGroupInternal &sg) { return padKernelSize(sg.getMaxConnections(), KernelInitializeSparse); },
@@ -1536,42 +1540,6 @@ void BackendSIMT::genInitializeSparseKernel(CodeStream &os, const Substitutions
CodeStream::Scope b(os);
os << "shRowLength[" << getThreadID() << "] = group->rowLength[(r * " << blockSize << ") + " << getThreadID() << "];" << std::endl;
}

// If this synapse group has synapse dynamics
if(isSynRemapRequired(sg.getArchetype())) {
genSharedMemBarrier(os);

// Use first thread to generate cumulative sum
os << "if(" << getThreadID() << " == 0)";
{
CodeStream::Scope b(os);

// Get index of last row in resultant synapse dynamics structure
// **NOTE** if there IS a previous block, it will always have had initSparseBlkSz rows in it
os << "unsigned int rowStart = (r == 0) ? 0 : shRowStart[" << blockSize << "];" << std::endl;
os << "shRowStart[0] = rowStart;" << std::endl;

// Loop through rows in block
os << "for(unsigned int i = 0; i < numRowsInBlock; i++)";
{
CodeStream::Scope b(os);

// Add this row's length to cumulative sum and write this to this row's end
os << "rowStart += shRowLength[i];" << std::endl;
os << "shRowStart[i + 1] = rowStart;" << std::endl;
}

// If this is the first thread block of the first block in the group AND the last block of rows,
// write the total cumulative sum to the first entry of the remap structure
os << "if(" << popSubs["id"] << " == 0 && (r == (numBlocks - 1)))";
{
CodeStream::Scope b(os);
os << "group->synRemap[0] = shRowStart[numRowsInBlock];" << std::endl;
}

}
}

genSharedMemBarrier(os);

// Loop through rows
@@ -1608,12 +1576,6 @@ void BackendSIMT::genInitializeSparseKernel(CodeStream &os, const Substitutions
// Add remapping entry at this location pointing back to row-major index
os << "group->remap[colMajorIndex] = idx;" << std::endl;
}

// If synapse remap is required, copy idx into the corresponding entry of the syn remap structure
if(isSynRemapRequired(sg.getArchetype())) {
CodeStream::Scope b(os);
os << "group->synRemap[shRowStart[i] + " + popSubs["id"] + " + 1] = idx;" << std::endl;
}
}

// If matrix is ragged, advance index to next row by adding stride
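The block deleted above is what built synRemap during sparse initialisation: the first thread of each block computed a cumulative sum of row lengths into shRowStart, every thread then wrote the row-major index of its synapse into synRemap at that row's offset, and the grand total was written to synRemap[0]. A hypothetical serial equivalent, purely to illustrate the structure that no longer needs to exist (buildSynRemap is not a GeNN function):

    // Host-side sketch of the structure the deleted kernel code produced
    void buildSynRemap(unsigned int numSrcNeurons, unsigned int rowStride,
                       const unsigned int *rowLength, unsigned int *synRemap)
    {
        unsigned int count = 0;
        for (unsigned int row = 0; row < numSrcNeurons; row++) {
            for (unsigned int col = 0; col < rowLength[row]; col++) {
                // Record the row-major index of each existing synapse, offset by one
                synRemap[1 + count++] = (row * rowStride) + col;
            }
        }
        synRemap[0] = count;    // first entry holds the total synapse count
    }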
21 changes: 5 additions & 16 deletions src/genn/genn/code_generator/customUpdateGroupMerged.cc
@@ -281,22 +281,11 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const
return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName();
});

// If the referenced synapse group requires synaptic remapping and matrix type is sparse, add field
if(backend.isSynRemapRequired(*getArchetype().getSynapseGroup())) {
addField("unsigned int*", "synRemap",
[&backend](const CustomUpdateWUInternal &cg, size_t)
{
return backend.getDeviceVarPrefix() + "synRemap" + cg.getSynapseGroup()->getName();
});
}
// Otherwise, add row length
else {
addField("unsigned int*", "rowLength",
[&backend](const CustomUpdateWUInternal &cg, size_t)
{
return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName();
});
}
addField("unsigned int*", "rowLength",
[&backend](const CustomUpdateWUInternal &cg, size_t)
{
return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName();
});
}

// Add heterogeneous custom update model parameters
8 changes: 0 additions & 8 deletions src/genn/genn/code_generator/generateRunner.cc
@@ -1260,14 +1260,6 @@ MemAlloc CodeGenerator::generateRunner(const filesystem::path &outputPath, const
backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
s.second.getSparseIndType(), "ind" + s.second.getName(), varLoc, size, mem);


// If synapse remap structure is required, allocate synRemap
// **THINK** this is over-allocating
if(backend.isSynRemapRequired(s.second)) {
backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
"unsigned int", "synRemap" + s.second.getName(), VarLocation::DEVICE, size + 1, mem);
}

// **TODO** remap is not always required
if(backend.isPostsynapticRemapRequired() && !s.second.getWUModel()->getLearnPostCode().empty()) {
const size_t postSize = (size_t)s.second.getTrgNeuronGroup()->getNumNeurons() * (size_t)s.second.getMaxSourceConnections();
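Dropping this allocation also saves device memory: synRemap was sized at size + 1 unsigned ints per sparse group, where size here is the padded matrix size (presumably numSrcNeurons × maxConnections, hence the over-allocation the **THINK** note refers to). As a purely illustrative example, a group with 10,000 presynaptic neurons and at most 128 connections per row would have carried a (10,000 × 128 + 1) × 4 B ≈ 5 MB synRemap array that is no longer allocated or initialised.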
7 changes: 0 additions & 7 deletions src/genn/genn/code_generator/groupMerged.cc
@@ -1105,13 +1105,6 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string &
addWeightSharingPointerField("unsigned int", "colLength", backend.getDeviceVarPrefix() + "colLength");
addWeightSharingPointerField("unsigned int", "remap", backend.getDeviceVarPrefix() + "remap");
}

// Add additional structure for synapse dynamics access if required
if((role == Role::SynapseDynamics || role == Role::SparseInit) &&
backend.isSynRemapRequired(getArchetype()))
{
addWeightSharingPointerField("unsigned int", "synRemap", backend.getDeviceVarPrefix() + "synRemap");
}
}
else if(getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) {
addWeightSharingPointerField("uint32_t", "gp", backend.getDeviceVarPrefix() + "gp");
1 change: 0 additions & 1 deletion src/genn/genn/code_generator/modelSpecMerged.cc
@@ -86,7 +86,6 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa
{
return ((sg.getMatrixType() & SynapseMatrixConnectivity::SPARSE) &&
(sg.isWUVarInitRequired()
|| backend.isSynRemapRequired(sg)
|| (backend.isPostsynapticRemapRequired() && !sg.getWUModel()->getLearnPostCode().empty())));
},
&SynapseGroupInternal::getWUInitHashDigest);
