Skip to content

Commit

Permalink
Closes #3336, #3362: Reuse random number generation loop structure (#…
Browse files Browse the repository at this point in the history
…3352)

* first pass

* closes #3362 fix bug in thisLocsNumChunks calculation

---------

Co-authored-by: Tess Hayes <[email protected]>
  • Loading branch information
stress-tess and stress-tess authored Jun 26, 2024
1 parent 3c8744d commit 7b51173
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 115 deletions.
136 changes: 21 additions & 115 deletions src/RandMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ module RandMsg
use Logging;
use Message;
use RandArray;
use RandUtil;
use CommAggregation;

use MultiTypeSymbolTable;
Expand Down Expand Up @@ -547,19 +548,34 @@ module RandMsg
return new MsgTuple(repMsg, MsgType.NORMAL);
}

// Draws a single Poisson-distributed integer with mean `lam` using Knuth's
// multiplication method, described here:
// https://en.wikipedia.org/wiki/Poisson_distribution#Random_variate_generation
// Uniform draws are taken from `rs` via `rs.next(0, 1)`; the stream is advanced
// by exactly one draw more than the value returned.
inline proc poissonGenerator(lam: real, ref rs) {
  // the running product of uniform draws is compared against exp(-lam)
  const threshold = exp(-lam);
  var count = 0;
  var prod: real = rs.next(0, 1);

  // every additional draw needed to fall to the threshold adds one to the result
  while prod > threshold {
    count += 1;
    prod *= rs.next(0, 1);
  }
  return count;
}

proc poissonGeneratorMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
const pn = Reflection.getRoutineName(),
name = msgArgs.getValueOf("name"), // generator name
isSingleLam = msgArgs.get("is_single_lambda").getBoolValue(), // boolean indicating if lambda is a single value or array
lamStr = msgArgs.getValueOf("lam"), // lambda for poisson distribution
size = msgArgs.get("size").getIntValue(), // number of values to be generated
hasSeed = msgArgs.get("has_seed").getBoolValue(), // boolean indicating if the generator has a seed
hasSeed = msgArgs.get("has_seed").getBoolValue(), // boolean indicating if the generator has a seed
state = msgArgs.get("state").getIntValue(), // rng state
rname = st.nextName();


randLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
"name: %? size %i isSingleLam %? lamStr %? state %i".doFormat(name, size, isSingleLam, lamStr, state));
"name: %? size %i hasSeed %? isSingleLam %? lamStr %? state %i".doFormat(name, size, hasSeed, isSingleLam, lamStr, state));

st.checkTable(name);

Expand All @@ -569,122 +585,12 @@ module RandMsg
// you have to skip to one before where you want to be
rng.skipTo(state-1);
}

// uses the algorithm from knuth found here:
// https://en.wikipedia.org/wiki/Poisson_distribution#Random_variate_generation
// generates values drawn from poisson distribution using a stream of uniformly distributed random numbers
var poissonArr = makeDistArray(size, int);
const lam = new scalarOrArray(lamStr, !isSingleLam, st);

if hasSeed {
// use a fixed number of elements per stream instead of relying on number of locales or numTasksPerLoc because these
// can vary from run to run / machine to machine. And it's important for the same seed to give the same results
const generatorSeed = (rng.next() * 2**62):int,
minPerStream = 256,
elemsPerStream = max(minPerStream, 2**(2 * ceil(log10(size)):int));

if isSingleLam {
const lam = lamStr:real;
// using nested coforalls over locales and tasks so we know how to generate taskSeed
coforall loc in Locales do on loc {
const locSubDom = poissonArr.localSubdomain(),
offset = if loc.id == 0 then 0 else elemsPerStream - (locSubDom.low % elemsPerStream);

// skip if all the values were pulled to previous locale
if offset <= locSubDom.high {
// we take the ceil in chunk calculation because if elemsPerStream doesn't evenly divide along locale boundaries, the remainder is pulled to the previous locale
const chunksAlreadyDone = if loc.id == 0 then 0 else ceil((locSubDom.low + 1) / elemsPerStream:real):int, // number of chunks handled by previous locales
thisLocsNumChunks = ceil((locSubDom.high + 1 - locSubDom.low + offset) / elemsPerStream:real):int; // number of chunks this locale needs to handle

coforall streamID in 0..<thisLocsNumChunks {
const taskSeed = generatorSeed + chunksAlreadyDone + streamID, // initial seed offset by other locales threads plus current thread id
startIdx = (streamID * elemsPerStream) + locSubDom.low + offset,
stopIdx = min(startIdx + elemsPerStream - 1, poissonArr.domain.high); // continue past end of localSubDomain to read full block to avoid seed sharing
var rs = new randomStream(real, taskSeed);
for i in startIdx..stopIdx {
var L = exp(-lam);
var k = 0;
var p = 1.0;

do {
k += 1;
p = p * rs.next(0, 1);
} while p > L;
poissonArr[i] = k - 1;
}
}
}
}
}
else {
st.checkTable(lamStr);
const lamArr = toSymEntry(getGenericTypedArrayEntry(lamStr, st),real).a;

// using nested coforalls over locales and tasks so we know how to generate taskSeed
coforall loc in Locales do on loc {
const locSubDom = poissonArr.localSubdomain(),
offset = if loc.id == 0 then 0 else elemsPerStream - (locSubDom.low % elemsPerStream);

// skip if all the values were pulled to previous locale
if offset <= locSubDom.high {
// we take the ceil in chunk calculation because if elemsPerStream doesn't evenly divide along locale boundaries, the remainder is pulled to the previous locale
const chunksAlreadyDone = if loc.id == 0 then 0 else ceil((locSubDom.low + 1) / elemsPerStream:real):int, // number of chunks handled by previous locales
thisLocsNumChunks = ceil((locSubDom.high + 1 - locSubDom.low + offset) / elemsPerStream:real):int; // number of chunks this locale needs to handle

coforall streamID in 0..<thisLocsNumChunks {
const taskSeed = generatorSeed + chunksAlreadyDone + streamID, // initial seed offset by other locales threads plus current thread id
startIdx = (streamID * elemsPerStream) + locSubDom.low + offset,
stopIdx = min(startIdx + elemsPerStream - 1, poissonArr.domain.high); // continue past end of localSubDomain to read full block to avoid seed sharing

var rs = new randomStream(real, taskSeed);
for i in startIdx..stopIdx {
const lam = lamArr[locSubDom.low + i];
var L = exp(-lam);
var k = 0;
var p = 1.0;

do {
k += 1;
p = p * rs.next(0, 1);
} while p > L;
poissonArr[i] = k - 1;
}
}
}
}
}
}
else { // non-seeded case, we can just use task private variables for our random streams
if isSingleLam {
const lam = lamStr:real;
forall pv in poissonArr with (var rs = new randomStream(real)) {
var L = exp(-lam);
var k = 0;
var p = 1.0;

do {
k += 1;
p = p * rs.next(0, 1);
} while p > L;
pv = k - 1;
}
}
else {
st.checkTable(lamStr);
const lamArr = toSymEntry(getGenericTypedArrayEntry(lamStr, st),real).a;
forall (pv, lam) in zip(poissonArr, lamArr) with (var rs = new randomStream(real)) {
var L = exp(-lam);
var k = 0;
var p = 1.0;

do {
k += 1;
p = p * rs.next(0, 1);
} while p > L;
pv = k - 1;
}
}
}
uniformStreamPerElem(poissonArr, GenerationFunction.PoissonGenerator, hasSeed, lam, rng);
st.addEntry(rname, createSymEntry(poissonArr));

const repMsg = "created " + st.attrib(rname);
randLogger.debug(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
Expand Down
90 changes: 90 additions & 0 deletions src/RandUtil.chpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
module RandUtil {
use MultiTypeSymbolTable;
use MultiTypeSymEntry;
use RandMsg;
use ArkoudaRandomCompat;

// Lower bound applied to elemsPerStream in uniformStreamPerElem's seeded path.
const minPerStream = 256; // minimum number of elements per random stream

// Wraps a parameter that is given either as one real value (parsed from a
// string) or as the name of a real array in the symbol table, so callers can
// index it the same way in both cases.
record scalarOrArray {
  var isArray: bool; // true when indexing should read from `sym`
  var sym; // TODO figure out type hint here to avoid generic
  var val: real; // the scalar; only assigned in the non-array case

  // scalarOrArrayString : symbol-table name when isArray, otherwise text cast to real
  // isArray             : selects which of the two interpretations applies
  // st                  : symbol table consulted in the array case
  proc init(scalarOrArrayString: string, isArray: bool, st: borrowed SymTab) {
    // I'm not sure if there's a good way to remove these try!
    this.isArray = isArray;
    if !isArray {
      // prob not the smartest way of doing this
      // a one-element array stands in for `sym` so we avoid
      // unnecessarily creating a large array
      this.sym = try! makeDistArray([0.0]);
      this.val = try! scalarOrArrayString:real;
    }
    else {
      try! st.checkTable(scalarOrArrayString);
      this.sym = try! toSymEntry(getGenericTypedArrayEntry(scalarOrArrayString, st),real).a;
    }
  }

  // Element `idx` of the array in the array case; the scalar (for any idx) otherwise.
  proc this(idx): real {
    if isArray then return this.sym[idx];
    return this.val;
  }
}

// Selects which per-element generator uniformStreamPerElem applies.
enum GenerationFunction {
PoissonGenerator,
}

// Overwrites every element of `randArr` with a value from the generator selected
// by `function`, each generator call drawing from a stream of uniform reals.
//
//   randArr  : array filled in place; the chunk arithmetic below treats index 0
//              as the start of chunk 0
//   function : compile-time selector of the per-element generator
//   hasSeed  : true  -> fixed-size chunks, one seeded stream per chunk
//              false -> one unseeded stream per task
//   lam      : generator parameter, read per element as lam[i]
//   rng      : its next() value derives the base seed (used only when hasSeed)
//
// NOTE(review): for an empty array log10(D.size) is -inf before the int cast;
// confirm callers never pass size 0 on the seeded path.
proc uniformStreamPerElem(ref randArr: [?D] ?t, param function: GenerationFunction, hasSeed: bool, const lam: scalarOrArray(?), ref rng) throws {
  if hasSeed {
    // use a fixed number of elements per stream instead of relying on number of locales or numTasksPerLoc because these
    // can vary from run to run / machine to machine. And it's important for the same seed to give the same results
    const generatorSeed = (rng.next() * 2**62):int,
          elemsPerStream = max(minPerStream, 2**(2 * ceil(log10(D.size)):int));

    // using nested coforalls over locales and tasks so we know how to generate taskSeed
    coforall loc in Locales do on loc {
      const locSubDom = randArr.localSubdomain(),
            // number of leading local elements that belong to a chunk started on an
            // earlier locale. The trailing modulo makes this 0 (not elemsPerStream)
            // when the local block begins exactly on a chunk boundary; without it
            // that whole chunk was never generated.
            offset = if loc.id == 0 then 0 else (elemsPerStream - (locSubDom.low % elemsPerStream)) % elemsPerStream,
            firstOwnedIdx = locSubDom.low + offset; // first index whose chunk starts on this locale

      // skip if all the values were pulled to previous locale
      if firstOwnedIdx <= locSubDom.high {
        // chunks starting before locSubDom.low are handled by previous locales; ceil because a chunk
        // straddling the locale boundary is pulled to the previous locale. Using low (not low + 1)
        // keeps the count right when low sits exactly on a chunk boundary and is otherwise identical.
        const chunksAlreadyDone = if loc.id == 0 then 0 else ceil(locSubDom.low / elemsPerStream:real):int, // number of chunks handled by previous locales
              thisLocsNumChunks = ceil((locSubDom.high + 1 - firstOwnedIdx) / elemsPerStream:real):int; // number of chunks this locale needs to handle

        coforall streamID in 0..<thisLocsNumChunks {
          const taskSeed = generatorSeed + chunksAlreadyDone + streamID, // initial seed offset by other locales threads plus current thread id
                startIdx = (streamID * elemsPerStream) + firstOwnedIdx,
                stopIdx = min(startIdx + elemsPerStream - 1, randArr.domain.high); // continue past end of localSubDomain to read full block to avoid seed sharing

          var rs = new randomStream(real, taskSeed);
          for i in startIdx..stopIdx {
            select function {
              // TODO look into adding copy aggregation looking here
              when GenerationFunction.PoissonGenerator {
                randArr[i] = poissonGenerator(lam[i], rs);
              }
              otherwise {
                compilerError("Unrecognized generation function");
              }
            }
          }
        }
      }
    }
  }
  else { // non-seeded case, we can just use task private variables for our random streams
    forall (rv, i) in zip(randArr, randArr.domain) with (var rs = new randomStream(real)) {
      select function {
        when GenerationFunction.PoissonGenerator {
          rv = poissonGenerator(lam[i], rs);
        }
        otherwise {
          compilerError("Unrecognized generation function");
        }
      }
    }
  }
}
}

0 comments on commit 7b51173

Please sign in to comment.