Skip to content

Commit

Permalink
random_wrap: update for future Julia versions
Browse files Browse the repository at this point in the history
  • Loading branch information
lukas-weber committed Oct 4, 2023
1 parent 4fc0dde commit 5f1080b
Showing 1 changed file with 29 additions and 9 deletions.
38 changes: 29 additions & 9 deletions src/random_wrap.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,41 @@
using Random
using HDF5

function write_checkpoint(rng::Random.Xoshiro, out::HDF5.Group)
function guess_xoshiro_version()
num_fields = length(fieldnames(Xoshiro))
if num_fields == 4
return 1
elseif num_fields == 5
return 2
end
error(
"Carlo wrapper does not support this version of Xoshiro yet. Please file a bug report",
)
end


function write_checkpoint(rng::Xoshiro, out::HDF5.Group)
out["type"] = "xoroshiro256++"
out["state"] = [rng.s0, rng.s1, rng.s2, rng.s3]
out["rng_version"] = 1
out["state"] = collect(getproperty.(rng, fieldnames(Xoshiro)))

out["rng_version"] = guess_xoshiro_version()

return nothing
end

function read_checkpoint(::Type{Random.Xoshiro}, in::HDF5.Group)
rng_type = in["type"]

if rng_type == "xoroshiro256++"
function read_checkpoint(::Type{Xoshiro}, in::HDF5.Group)
rng_type = read(in["type"])
if rng_type != "xoroshiro256++"
error("checkpoint was done with a different RNG: $(rng_type)")
end

state = in["state"]
return Random.Xoshiro(state[1], state[2], state[3], state[4])
rng_version = read(in["rng_version"])
if rng_version != guess_xoshiro_version()
error(
"checkpoint was done with a different version of Xoshiro. Try running with the version of Julia you used originally.",
)
end

state = read(in["state"])
return Random.Xoshiro(state...)
end

0 comments on commit 5f1080b

Please sign in to comment.