Skip to content

Commit

Permalink
make Julia global RNG an instance of MersenneTwister
Browse files Browse the repository at this point in the history
  • Loading branch information
rfourquet committed Sep 17, 2014
1 parent 8f8e43d commit 7126d7f
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions base/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ type MersenneTwister <: AbstractRNG
MersenneTwister() = MersenneTwister(0)
end

immutable GlobalRNG <: AbstractRNG
end

typealias MT Union(GlobalRNG, MersenneTwister)

# random numbers of the following types are implemented:
MTRealTypes = Union(Bool, Integer128, Float16, Float32, Float64)
Expand Down Expand Up @@ -100,8 +96,8 @@ srand() = srand(make_seed())
srand(r::MersenneTwister) = srand(r, make_seed())

function srand(seed::Union(Integer, Vector{Uint32}))
global RANDOM_SEED = make_seed(seed)
dsfmt_gv_init_by_array(RANDOM_SEED)
srand(GlobalRNG(), seed)
global RANDOM_SEED = GLOBAL_RNG.seed
end

function srand(r::MersenneTwister, seed::Union(Integer, Vector{Uint32}))
Expand All @@ -114,28 +110,31 @@ srand(filename::String, n::Integer=4) = srand(make_seed(filename, n))

srand(r::MersenneTwister, filename::String, n::Integer=4) = srand(r, make_seed(filename, n))

## Global RNG: needs srand to be instanciated

GLOBAL_RNG = MersenneTwister()
GlobalRNG() = GLOBAL_RNG

## random floating point values

rand(r::AbstractRNG) = rand(r, Float64)

# MersenneTwister
rand(r::GlobalRNG, ::Type{Float64}) = dsfmt_gv_genrand_close_open()
rand(r::MersenneTwister, ::Type{Float64}) = dsfmt_genrand_close_open(r.state)

rand{T<:Union(Float16, Float32)}(r::MT, ::Type{T}) = convert(T, rand(r))
rand{T<:Union(Float16, Float32)}(r::MersenneTwister, ::Type{T}) = convert(T, rand(r))

## random integers (MersenneTwister)

rand(::GlobalRNG, ::Type{Uint32}) = dsfmt_gv_genrand_uint32()
rand(r::MersenneTwister, ::Type{Uint32}) = dsfmt_genrand_uint32(r.state)

rand(r::MT, ::Type{Bool}) = bool( rand(r, Uint32) & 1)
rand(r::MT, ::Type{Uint8}) = uint8( rand(r, Uint32))
rand(r::MT, ::Type{Uint16}) = uint16( rand(r, Uint32))
rand(r::MT, ::Type{Uint64}) = uint64( rand(r, Uint32))<<32 | rand(r, Uint32)
rand(r::MT, ::Type{Uint128}) = uint128(rand(r, Uint64))<<64 | rand(r, Uint64)
rand(r::MersenneTwister, ::Type{Bool}) = bool( rand(r, Uint32) & 1)
rand(r::MersenneTwister, ::Type{Uint8}) = uint8( rand(r, Uint32))
rand(r::MersenneTwister, ::Type{Uint16}) = uint16( rand(r, Uint32))
rand(r::MersenneTwister, ::Type{Uint64}) = uint64( rand(r, Uint32))<<32 | rand(r, Uint32)
rand(r::MersenneTwister, ::Type{Uint128}) = uint128(rand(r, Uint64))<<64 | rand(r, Uint64)

rand{T<:Signed128}(r::MT, ::Type{T}) = convert(T, rand(r, as_unsigned(T)))
rand{T<:Signed128}(r::MersenneTwister, ::Type{T}) = convert(T, rand(r, as_unsigned(T)))

## random complex values (AbstractRNG)

Expand Down Expand Up @@ -751,8 +750,7 @@ ziggurat_nor_inv_r = inv(ziggurat_nor_r)
ziggurat_exp_r = 7.6971174701310497140446280481

rand(state::DSFMT_state) = dsfmt_genrand_close_open(state)
randi() = reinterpret(Uint64,dsfmt_gv_genrand_close1_open2()) & 0x000fffffffffffff
randi(state::DSFMT_state) = reinterpret(Uint64,dsfmt_genrand_close1_open2(state)) & 0x000fffffffffffff
randi(state::DSFMT_state=GLOBAL_RNG.state) = reinterpret(Uint64,dsfmt_genrand_close1_open2(state)) & 0x000fffffffffffff
for (lhs, rhs) in (([], []),
([:(state::DSFMT_state)], [:state]))
@eval begin
Expand Down

0 comments on commit 7126d7f

Please sign in to comment.