Skip to content

Commit

Permalink
Merge pull request #498 from genn-team/binomial_distribution
Browse files Browse the repository at this point in the history
Binomial distribution
  • Loading branch information
neworderofjamie authored Jan 25, 2022
2 parents ad8179d + df71783 commit acad848
Show file tree
Hide file tree
Showing 8 changed files with 335 additions and 33 deletions.
25 changes: 22 additions & 3 deletions include/genn/genn/initVarSnippet.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,11 @@ class Exponential : public Base
//----------------------------------------------------------------------------
// InitVarSnippet::Gamma
//----------------------------------------------------------------------------
//! Initialises variable by sampling from the exponential distribution
/*! This snippet takes 1 parameter:
//! Initialises variable by sampling from the gamma distribution
/*! This snippet takes s parameters:
*
- \c lambda - mean event rate (events per unit time/distance)*/
- \c a - distribution shape
- \c b - distribution scale*/
class Gamma : public Base
{
public:
Expand All @@ -210,4 +211,22 @@ class Gamma : public Base

SET_PARAM_NAMES({"a", "b"});
};

//----------------------------------------------------------------------------
// InitVarSnippet::Binomial
//----------------------------------------------------------------------------
//! Initialises variable by sampling from the binomial distribution
/*! This snippet takes 2 parameters:
*
- \c n - number of trials
- \c p - success probability for each trial*/
class Binomial : public Base
{
public:
DECLARE_SNIPPET(InitVarSnippet::Binomial, 2);

SET_CODE("$(value) = $(gennrand_binomial, (unsigned int)$(n), $(p));");

SET_PARAM_NAMES({"n", "p"});
};
} // namespace InitVarSnippet
112 changes: 110 additions & 2 deletions src/genn/backends/cuda/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,17 @@ const std::vector<Substitutions::FunctionTemplate> cudaSinglePrecisionFunctions
{"gennrand_normal", 0, "curand_normal($(rng))"},
{"gennrand_exponential", 0, "exponentialDistFloat($(rng))"},
{"gennrand_log_normal", 2, "curand_log_normal_float($(rng), $(0), $(1))"},
{"gennrand_gamma", 1, "gammaDistFloat($(rng), $(0))"}
{"gennrand_gamma", 1, "gammaDistFloat($(rng), $(0))"},
{"gennrand_binomial", 2, "binomialDistFloat($(rng), $(0), $(1))"}
};
//--------------------------------------------------------------------------
const std::vector<Substitutions::FunctionTemplate> cudaDoublePrecisionFunctions = {
{"gennrand_uniform", 0, "curand_uniform_double($(rng))"},
{"gennrand_normal", 0, "curand_normal_double($(rng))"},
{"gennrand_exponential", 0, "exponentialDistDouble($(rng))"},
{"gennrand_log_normal", 2, "curand_log_normal_double($(rng), $(0), $(1))"},
{"gennrand_gamma", 1, "gammaDistDouble($(rng), $(0))"}
{"gennrand_gamma", 1, "gammaDistDouble($(rng), $(0))"},
{"gennrand_binomial", 2, "binomialDistDouble($(rng), $(0), $(1))"}
};
//--------------------------------------------------------------------------
// Timer
Expand Down Expand Up @@ -1288,6 +1290,112 @@ void Backend::genDefinitionsInternalPreamble(CodeStream &os, const ModelSpecMerg
}
}
os << std::endl;

// The following code is an almost exact copy of numpy's
// rk_binomial_inversion function (numpy/random/mtrand/distributions.c)
os << "template<typename RNG>" << std::endl;
os << "__device__ inline unsigned int binomialDistFloatInternal(RNG *rng, unsigned int n, float p)" << std::endl;
{
CodeStream::Scope b(os);
os << "const float q = 1.0f - p;" << std::endl;
os << "const float qn = expf(n * logf(q));" << std::endl;
os << "const float np = n * p;" << std::endl;
os << "const unsigned int bound = min(n, (unsigned int)(np + (10.0f * sqrtf((np * q) + 1.0f))));" << std::endl;

os << "unsigned int x = 0;" << std::endl;
os << "float px = qn;" << std::endl;
os << "float u = curand_uniform(rng);" << std::endl;
os << "while(u > px)" << std::endl;
{
CodeStream::Scope b(os);
os << "x++;" << std::endl;
os << "if(x > bound)";
{
CodeStream::Scope b(os);
os << "x = 0;" << std::endl;
os << "px = qn;" << std::endl;
os << "u = curand_uniform(rng);" << std::endl;
}
os << "else";
{
CodeStream::Scope b(os);
os << "u -= px;" << std::endl;
os << "px = ((n - x + 1) * p * px) / (x * q);" << std::endl;
}
}
os << "return x;" << std::endl;
}
os << std::endl;

os << "template<typename RNG>" << std::endl;
os << "__device__ inline unsigned int binomialDistFloat(RNG *rng, unsigned int n, float p)" << std::endl;
{
CodeStream::Scope b(os);
os << "if(p <= 0.5f)";
{
CodeStream::Scope b(os);
os << "return binomialDistFloatInternal(rng, n, p);" << std::endl;

}
os << "else";
{
CodeStream::Scope b(os);
os << "return (n - binomialDistFloatInternal(rng, n, 1.0f - p));" << std::endl;
}
}

// The following code is an almost exact copy of numpy's
// rk_binomial_inversion function (numpy/random/mtrand/distributions.c)
os << "template<typename RNG>" << std::endl;
os << "__device__ inline unsigned int binomialDistDoubleInternal(RNG *rng, unsigned int n, double p)" << std::endl;
{
CodeStream::Scope b(os);
os << "const double q = 1.0 - p;" << std::endl;
os << "const double qn = exp(n * log(q));" << std::endl;
os << "const double np = n * p;" << std::endl;
os << "const unsigned int bound = min(n, (unsigned int)(np + (10.0 * sqrt((np * q) + 1.0))));" << std::endl;

os << "unsigned int x = 0;" << std::endl;
os << "double px = qn;" << std::endl;
os << "double u = curand_uniform_double(rng);" << std::endl;
os << "while(u > px)" << std::endl;
{
CodeStream::Scope b(os);
os << "x++;" << std::endl;
os << "if(x > bound)";
{
CodeStream::Scope b(os);
os << "x = 0;" << std::endl;
os << "px = qn;" << std::endl;
os << "u = curand_uniform_double(rng);" << std::endl;
}
os << "else";
{
CodeStream::Scope b(os);
os << "u -= px;" << std::endl;
os << "px = ((n - x + 1) * p * px) / (x * q);" << std::endl;
}
}
os << "return x;" << std::endl;
}
os << std::endl;

os << "template<typename RNG>" << std::endl;
os << "__device__ inline unsigned int binomialDistDouble(RNG *rng, unsigned int n, double p)" << std::endl;
{
CodeStream::Scope b(os);
os << "if(p <= 0.5)";
{
CodeStream::Scope b(os);
os << "return binomialDistDoubleInternal(rng, n, p);" << std::endl;

}
os << "else";
{
CodeStream::Scope b(os);
os << "return (n - binomialDistDoubleInternal(rng, n, 1.0 - p));" << std::endl;
}
}
}
//--------------------------------------------------------------------------
void Backend::genRunnerPreamble(CodeStream &os, const ModelSpecMerged&, const MemAlloc&) const
Expand Down
57 changes: 55 additions & 2 deletions src/genn/backends/opencl/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,17 @@ const std::vector<Substitutions::FunctionTemplate> openclLFSRFunctions = {
{"gennrand_normal", 0, "normalDistLfsr113($(rng))"},
{"gennrand_exponential", 0, "exponentialDistLfsr113($(rng))"},
{"gennrand_log_normal", 2, "logNormalDistLfsr113($(rng), $(0), $(1))"},
{"gennrand_gamma", 1, "gammaDistLfsr113($(rng), $(0))"}
{"gennrand_gamma", 1, "gammaDistLfsr113($(rng), $(0))"},
{"gennrand_binomial", 2, "binomialDistLfsr113($(rng), $(0), $(1))"}
};
//-----------------------------------------------------------------------
const std::vector<Substitutions::FunctionTemplate> openclPhilloxFunctions = {
{"gennrand_uniform", 0, "clrngPhilox432RandomU01($(rng))"},
{"gennrand_normal", 0, "normalDistPhilox432($(rng))"},
{"gennrand_exponential", 0, "exponentialDistPhilox432($(rng))"},
{"gennrand_log_normal", 2, "logNormalDistPhilox432($(rng), $(0), $(1))"},
{"gennrand_gamma", 1, "gammaDistPhilox432($(rng), $(0))"}
{"gennrand_gamma", 1, "gammaDistPhilox432($(rng), $(0))"},
{"gennrand_binomial", 2, "binomialDistPhilox432($(rng), $(0), $(1))"}
};
//--------------------------------------------------------------------------
template<typename T>
Expand Down Expand Up @@ -2644,6 +2646,57 @@ void Backend::genKernelPreamble(CodeStream &os, const ModelSpecMerged &modelMerg
os << "return gammaDistInternal" << r << "(rng, c, d);" << std::endl;
}
}

// The following code is an almost exact copy of numpy's
// rk_binomial_inversion function (numpy/random/mtrand/distributions.c)
os << "inline unsigned int binomialDist" << r << "Internal(clrng" << r << "Stream *rng, unsigned int n, " << precision << " p)" << std::endl;
{
CodeStream::Scope b(os);
os << "const " << precision << " q = " << model.scalarExpr(1.0) << " - p;" << std::endl;
os << "const " << precision << " qn = exp(n * log(q));" << std::endl;
os << "const " << precision << " np = n * p;" << std::endl;
os << "const unsigned int bound = min(n, (unsigned int)(np + (" << model.scalarExpr(10.0) << " * sqrt((np * q) + " << model.scalarExpr(1.0) << "))));" << std::endl;

os << "unsigned int x = 0;" << std::endl;
os << precision << " px = qn;" << std::endl;
os << precision << " u = clrng" << r << "RandomU01(rng);" << std::endl;
os << "while(u > px)" << std::endl;
{
CodeStream::Scope b(os);
os << "x++;" << std::endl;
os << "if(x > bound)";
{
CodeStream::Scope b(os);
os << "x = 0;" << std::endl;
os << "px = qn;" << std::endl;
os << "u = clrng" << r << "RandomU01(rng);" << std::endl;
}
os << "else";
{
CodeStream::Scope b(os);
os << "u -= px;" << std::endl;
os << "px = ((n - x + 1) * p * px) / (x * q);" << std::endl;
}
}
os << "return x;" << std::endl;
}
os << std::endl;

os << "inline unsigned int binomialDist" << r << "(clrng" << r << "Stream *rng, unsigned int n, " << precision << " p)" << std::endl;
{
CodeStream::Scope b(os);
os << "if(p <= " << model.scalarExpr(0.5) << ")";
{
CodeStream::Scope b(os);
os << "return binomialDist" << r << "Internal(rng, n, p);" << std::endl;

}
os << "else";
{
CodeStream::Scope b(os);
os << "return (n - binomialDist" << r << "Internal(rng, n, " << model.scalarExpr(1.0) << " - p));" << std::endl;
}
}
os << std::endl;
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/genn/backends/single_threaded_cpu/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@ const std::vector<Substitutions::FunctionTemplate> cpuSinglePrecisionFunctions =
{"gennrand_normal", 0, "standardNormalDistribution($(rng))"},
{"gennrand_exponential", 0, "standardExponentialDistribution($(rng))"},
{"gennrand_log_normal", 2, "std::lognormal_distribution<float>($(0), $(1))($(rng))"},
{"gennrand_gamma", 1, "std::gamma_distribution<float>($(0), 1.0f)($(rng))"}
{"gennrand_gamma", 1, "std::gamma_distribution<float>($(0), 1.0f)($(rng))"},
{"gennrand_binomial", 2, "std::binomial_distribution<unsigned int>($(0), $(1))($(rng))"}
};
//--------------------------------------------------------------------------
const std::vector<Substitutions::FunctionTemplate> cpuDoublePrecisionFunctions = {
{"gennrand_uniform", 0, "standardUniformDistribution($(rng))"},
{"gennrand_normal", 0, "standardNormalDistribution($(rng))"},
{"gennrand_exponential", 0, "standardExponentialDistribution($(rng))"},
{"gennrand_log_normal", 2, "std::lognormal_distribution<double>($(0), $(1))($(rng))"},
{"gennrand_gamma", 1, "std::gamma_distribution<double>($(0), 1.0)($(rng))"}
{"gennrand_gamma", 1, "std::gamma_distribution<double>($(0), 1.0)($(rng))"},
{"gennrand_binomial", 2, "std::binomial_distribution<unsigned int>($(0), $(1))($(rng))"}
};

//--------------------------------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion src/genn/genn/gennUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ GenericFunction randomFuncs[] = {
{"gennrand_normal", 0},
{"gennrand_exponential", 0},
{"gennrand_log_normal", 2},
{"gennrand_gamma", 1}
{"gennrand_gamma", 1},
{"gennrand_binomial", 2}
};
}

Expand Down
1 change: 1 addition & 0 deletions src/genn/genn/initVarSnippet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ IMPLEMENT_SNIPPET(InitVarSnippet::NormalClipped);
IMPLEMENT_SNIPPET(InitVarSnippet::NormalClippedDelay);
IMPLEMENT_SNIPPET(InitVarSnippet::Exponential);
IMPLEMENT_SNIPPET(InitVarSnippet::Gamma);
IMPLEMENT_SNIPPET(InitVarSnippet::Binomial);

//----------------------------------------------------------------------------
// InitVarSnippet::Base
Expand Down
Loading

0 comments on commit acad848

Please sign in to comment.