Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Step 2 - Variable dimensions #598

Merged
merged 60 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
0d57043
define new access types based on dimensions
neworderofjamie Aug 16, 2023
6be6abd
started refactoring variable access -Made ``Models::Base::Var::accces…
neworderofjamie Aug 16, 2023
623719b
added accessors for inserting default
neworderofjamie Aug 17, 2023
deb4367
* Remove VarAccess and VarAccessDuplication
neworderofjamie Aug 17, 2023
f875d94
More iteration on this:
neworderofjamie Aug 18, 2023
cca56b7
sprinkled template around to appease GCC
neworderofjamie Aug 18, 2023
0b89d8e
added CustomUpdateVarAccess and more elegant isValid<> method to VarA…
neworderofjamie Aug 18, 2023
b413771
fixed typos
neworderofjamie Aug 18, 2023
4b43b8a
flipped test in CustomUpdateBase::isReduction to match new VarAccess …
neworderofjamie Aug 18, 2023
a4b872d
fixed typo
neworderofjamie Aug 18, 2023
7ba2674
correct test for custom update model validity
neworderofjamie Aug 18, 2023
1a60ad8
Renamed ``Models::checkVarReferences`` to ``Models::checkVarReference…
neworderofjamie Aug 18, 2023
fa8712b
fixed compiler errors
neworderofjamie Aug 29, 2023
e38d861
fixed unit tests
neworderofjamie Aug 29, 2023
492835f
removed GENN_EXPORT
neworderofjamie Aug 29, 2023
524e014
updated PyBind11 wrapper
neworderofjamie Aug 29, 2023
25e9dce
at least fixed PyGeNN interface
neworderofjamie Aug 29, 2023
a304015
update feature tests
neworderofjamie Aug 29, 2023
3874ed1
fixed warning in test
neworderofjamie Aug 29, 2023
a611308
fixed some more warnings
neworderofjamie Aug 29, 2023
8af2093
fixed typo in NeuronVarAccess::READ_ONLY and NeuronVarAccess::READ_ON…
neworderofjamie Aug 29, 2023
7047f0b
add clarifying comment to unit test
neworderofjamie Aug 29, 2023
8253564
fixed logic in ``VarReference::getDims`` and ``WUVarReference::getDim…
neworderofjamie Aug 29, 2023
b0fd670
fixed typo which meant death test was dying for incorrect reason!
neworderofjamie Aug 29, 2023
8a21a5f
Fixed a couple of issues in ``CustomUpdateBase::isReduction``
neworderofjamie Aug 29, 2023
da8f180
infact, we should ALWAYS call getDims directly on variable references…
neworderofjamie Aug 29, 2023
608c5e6
Correct resolution of dimensions in runner and initialisation
neworderofjamie Aug 30, 2023
42f0552
fixed some typos in custom update group merged
neworderofjamie Sep 1, 2023
e152bcf
extend VarAccess.get_custom_update_dims to do clearDim logic
neworderofjamie Sep 1, 2023
68d6bc3
corrected getSynapseVarSize logic and changed to take SynapseGroup
neworderofjamie Sep 1, 2023
3b26adf
updated variable shape-getting logic in PyGeNN
neworderofjamie Sep 1, 2023
23f1416
fix bug in environment
neworderofjamie Sep 1, 2023
6afbe5e
incorrect flags
neworderofjamie Sep 1, 2023
5dffc5c
PyBind11 wrapper for ``Var`` cannot implicitly convert ``NeuronVarAcc…
neworderofjamie Sep 1, 2023
426d100
removed more calls to getDims on var rather than var reference and ad…
neworderofjamie Sep 1, 2023
79cac2f
don't generate code for empty custom updates
neworderofjamie Sep 1, 2023
1a13a3a
infer batchedness from shape in _init_wum_var
neworderofjamie Sep 1, 2023
8efb89e
fixed dimension flags being used backwards
neworderofjamie Sep 1, 2023
600df8b
correctly resolve custom update var accesses in reductions
neworderofjamie Sep 1, 2023
62f0e05
fixed some variable access typos in test_custom_update
neworderofjamie Sep 1, 2023
9fce215
fixed typos in _init_wum_var
neworderofjamie Sep 1, 2023
995c7b8
typo in loading of custom connectivity update groups
neworderofjamie Sep 1, 2023
755d192
* switched models to use vectors of correct enum type
neworderofjamie Sep 5, 2023
d87a3af
updated adapters
neworderofjamie Sep 5, 2023
5b5efdb
moved generic model validation and hashing out of base class
neworderofjamie Sep 5, 2023
609386d
WIP var reference
neworderofjamie Sep 8, 2023
eae1e11
fixed up code base
neworderofjamie Sep 8, 2023
818b283
rename getAcccessDim to getVarAcccessDim for consistency
neworderofjamie Sep 8, 2023
e39eb98
updated SET_PRE_VARS and SET_POST_VARS macros
neworderofjamie Sep 8, 2023
0d40ab2
tests compile
neworderofjamie Sep 8, 2023
a48b5d4
fixed bug
neworderofjamie Sep 8, 2023
2d17150
fixed CUDA backend
neworderofjamie Sep 8, 2023
0011919
start updating GeNN wrapper
neworderofjamie Sep 8, 2023
6e74f6a
fix std::tieing of temporary variable
neworderofjamie Sep 18, 2023
2c2fe38
fixed wrapping of model classes
neworderofjamie Sep 18, 2023
c6ab9df
fixed typos
neworderofjamie Sep 18, 2023
2fbf52f
add additional test to custom updates to check whether REDUCE access …
neworderofjamie Sep 26, 2023
1410f87
simplify variable dimensions:
neworderofjamie Oct 31, 2023
14dca3a
update PygeNn to reflect simplified design
neworderofjamie Oct 31, 2023
d3d12c9
moved one macro around
neworderofjamie Oct 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/genn/backends/cuda/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT
[this, v](const auto &g, size_t)
{
const auto varRef = g.getVarReferences().at(v.name);
return getDeviceVarPrefix() + varRef.getVar().name + varRef.getTargetName(); ;
return getDeviceVarPrefix() + varRef.getVarName() + varRef.getTargetName(); ;
});

// Add NCCL reduction
Expand Down
3 changes: 2 additions & 1 deletion include/genn/backends/single_threaded_cpu/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ class BACKEND_EXPORT Backend : public BackendBase
// If variable is a reduction target, copy value from register straight back into global memory
if(v.access & VarAccessModeAttribute::REDUCE) {
const std::string idx = env.getName(idxName);
env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(getVarAccessDuplication(v.access), idx) << "] = " << env[v.name] << ";" << std::endl;
const VarAccessDim varAccessDim = getVarAccessDim(v.access, cg.getArchetype().getDims());
env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(1, varAccessDim, idx) << "] = " << env[v.name] << ";" << std::endl;
}
}

Expand Down
26 changes: 15 additions & 11 deletions include/genn/genn/code_generator/backendBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,16 +495,16 @@ class GENN_EXPORT BackendBase
void buildStandardEnvironment(EnvironmentGroupMergedField<PostsynapticUpdateGroupMerged> &env, unsigned int batchSize) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<SynapseDynamicsGroupMerged> &env, unsigned int batchSize) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<SynapseDendriticDelayUpdateGroupMerged> &env, unsigned int batchSize) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomUpdateGroupMerged> &env) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomUpdateWUGroupMerged> &env) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomUpdateTransposeWUGroupMerged> &env) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomUpdateGroupMerged> &env, unsigned int batchSize) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomUpdateWUGroupMerged> &env, unsigned int batchSize) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomUpdateTransposeWUGroupMerged> &env, unsigned int batchSize) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomConnectivityUpdateGroupMerged> &env) const;

void buildStandardEnvironment(EnvironmentGroupMergedField<NeuronInitGroupMerged> &env, unsigned int batchSize) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<SynapseInitGroupMerged> &env, unsigned int batchSize) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomUpdateInitGroupMerged> &env) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomWUUpdateInitGroupMerged> &env) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomWUUpdateSparseInitGroupMerged> &env) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomUpdateInitGroupMerged> &env, unsigned int batchSize) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomWUUpdateInitGroupMerged> &env, unsigned int batchSize) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomWUUpdateSparseInitGroupMerged> &env, unsigned int batchSize) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomConnectivityUpdatePreInitGroupMerged> &env) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<CustomConnectivityUpdatePostInitGroupMerged> &env) const;
void buildStandardEnvironment(EnvironmentGroupMergedField<SynapseSparseInitGroupMerged> &env, unsigned int batchSize) const;
Expand Down Expand Up @@ -550,18 +550,21 @@ class GENN_EXPORT BackendBase

//! Helper function to generate initialisation code for any reduction operations carried out be custom update group.
//! Returns vector of ReductionTarget structs, providing all information to write back reduction results to memory
std::vector<ReductionTarget> genInitReductionTargets(CodeStream &os, const CustomUpdateGroupMerged &cg, const std::string &idx = "") const;
std::vector<ReductionTarget> genInitReductionTargets(CodeStream &os, const CustomUpdateGroupMerged &cg,
unsigned int batchSize, const std::string &idx = "") const;

//! Helper function to generate initialisation code for any reduction operations carried out be custom weight update group.
//! //! Returns vector of ReductionTarget structs, providing all information to write back reduction results to memory
std::vector<ReductionTarget> genInitReductionTargets(CodeStream &os, const CustomUpdateWUGroupMerged &cg, const std::string &idx = "") const;
std::vector<ReductionTarget> genInitReductionTargets(CodeStream &os, const CustomUpdateWUGroupMerged &cg,
unsigned int batchSize, const std::string &idx = "") const;

private:
//--------------------------------------------------------------------------
// Private API
//--------------------------------------------------------------------------
template<typename G, typename R>
std::vector<ReductionTarget> genInitReductionTargets(CodeStream &os, const G &cg, const std::string &idx, R getVarRefIndexFn) const
template<typename A, typename G, typename R>
std::vector<ReductionTarget> genInitReductionTargets(CodeStream &os, const G &cg, unsigned int batchSize,
const std::string &idx, R getVarRefIndexFn) const
{
// Loop through variables
std::vector<ReductionTarget> reductionTargets;
Expand All @@ -571,8 +574,9 @@ class GENN_EXPORT BackendBase
if (v.access & VarAccessModeAttribute::REDUCE) {
const auto resolvedType = v.type.resolve(cg.getTypeContext());
os << resolvedType.getName() << " _lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(v.access), resolvedType) << ";" << std::endl;
const VarAccessDim varAccessDim = getVarAccessDim(v.access, cg.getArchetype().getDims());
reductionTargets.push_back({v.name, resolvedType, getVarAccessMode(v.access),
cg.getVarIndex(getVarAccessDuplication(v.access), idx)});
cg.getVarIndex(batchSize, varAccessDim, idx)});
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public GroupMerged<Custo
[arrayPrefix, v](const auto &g, size_t)
{
const auto varRef = A(g).getInitialisers().at(v.name);
return arrayPrefix + varRef.getVar().name + varRef.getTargetName();
return arrayPrefix + varRef.getVarName() + varRef.getTargetName();
});
}
}
Expand All @@ -94,7 +94,7 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public GroupMerged<Custo
for(const auto &v : archetypeAdaptor.getDefs()) {
// If model isn't batched or variable isn't duplicated
const auto &varRef = archetypeAdaptor.getInitialisers().at(v.name);
if(batchSize == 1 || !varRef.isDuplicated()) {
if(batchSize == 1 || !(varRef.getVarDims() & VarAccessDim::BATCH)) {
// Add field with qualified type which indexes private pointer field
const auto resolvedType = v.type.resolve(getTypeContext());
const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType;
Expand Down
12 changes: 6 additions & 6 deletions include/genn/genn/code_generator/customUpdateGroupMerged.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged<CustomUpdateInter
runnerVarDecl, runnerMergedStructAlloc, name);
}

void generateCustomUpdate(const BackendBase &backend, EnvironmentExternalBase &env,
void generateCustomUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize,
BackendBase::GroupHandlerEnv<CustomUpdateGroupMerged> genPostamble);

std::string getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const;
std::string getVarRefIndex(bool delay, VarAccessDuplication varDuplication, const std::string &index) const;
std::string getVarIndex(unsigned int batchSize, VarAccessDim varDims, const std::string &index) const;
std::string getVarRefIndex(bool delay, unsigned int batchSize, VarAccessDim varDims, const std::string &index) const;

//----------------------------------------------------------------------------
// Static constants
Expand Down Expand Up @@ -64,11 +64,11 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged<CustomUpdat

boost::uuids::detail::sha1::digest_type getHashDigest() const;

void generateCustomUpdate(const BackendBase &backend, EnvironmentExternalBase &env,
void generateCustomUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize,
BackendBase::GroupHandlerEnv<CustomUpdateWUGroupMergedBase> genPostamble);

std::string getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const;
std::string getVarRefIndex(VarAccessDuplication varDuplication, const std::string &index) const;
std::string getVarIndex(unsigned int batchSize, VarAccessDim varDims, const std::string &index) const;
std::string getVarRefIndex(unsigned int batchSize, VarAccessDim varDims, const std::string &index) const;

};

Expand Down
Loading