Restart and store trace #1031
Yes.
No, it's not, and you can't really do it.
You can only do it by running for a number of iterations and then starting from that final point... sorry. Instead of using the trace system, you can also store information in your objective function and control it that way. Then you can save it to a file and not worry about memory.
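A minimal sketch of the "store information in your objective function" approach, assuming all you need is the sequence of objective values. `LoggedObjective` and `flush_log!` are hypothetical names (not part of Optim): the wrapper keeps at most `flush_every` values in memory and appends them to a plain text file whenever the buffer fills.

```julia
# Hypothetical objective wrapper (not part of Optim): keeps at most
# `flush_every` objective values in memory and appends them to a text file,
# one value per line, whenever the buffer fills up.
struct LoggedObjective{F}
    f::F                     # the user's objective function
    path::String             # log file
    buffer::Vector{Float64}  # values not yet written to disk
    flush_every::Int         # buffer length that triggers a write
end

LoggedObjective(f, path; flush_every = 100) =
    LoggedObjective(f, path, Float64[], flush_every)

# Append the buffered values to the log file and empty the buffer.
function flush_log!(lo::LoggedObjective)
    open(lo.path, "a") do io
        foreach(v -> println(io, v), lo.buffer)
    end
    empty!(lo.buffer)
    return nothing
end

# Evaluate the objective, remember its value, and flush when the buffer is full.
function (lo::LoggedObjective)(x)
    val = lo.f(x)
    push!(lo.buffer, val)
    length(lo.buffer) >= lo.flush_every && flush_log!(lo)
    return val
end

# Usage without Optim: 7 calls with a buffer of 3 give two automatic flushes,
# then one manual flush writes the last value.
path = tempname()
lo = LoggedObjective(x -> sum(abs2, x), path; flush_every = 3)
for k in 1:7
    lo(fill(Float64(k), 2))
end
flush_log!(lo)
vals = parse.(Float64, readlines(path))
```

In an actual run you would pass the wrapper where the objective normally goes, e.g. `optimize(lo, x0, NelderMead())`, and call `flush_log!(lo)` once `optimize` returns. Values still in the buffer are lost if the process dies, so a smaller `flush_every` trades more file I/O for less lost history.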
I hacked around to do this, in case it is helpful. It uses some Optim internals and serializes the optimizer state:

```julia
using Optim, UUIDs, Serialization

# Wraps the objective function and periodically serializes the optimizer state to disk.
# Note: `const` fields in a mutable struct require Julia 1.8 or later.
mutable struct SaveStateWrapper{F,S}
    const obj::F                # Objective function
    const optimstate::S         # Optim state, mutated in place during the run
    const filename::String      # File the state is serialized to
    const num_calls_per_save::Int
    num_calls_since_save::Int
end

function SaveStateWrapper(obj, optimstate; num_calls_per_save)
    filename = string(uuid1()) * ".state"
    @info "Creating SaveStateWrapper" filename
    return SaveStateWrapper(obj, optimstate, filename, num_calls_per_save, 0)
end

# Move the previous state aside, write the new one, then delete the old copy.
# If serialization is interrupted, the last good state survives in `filename * "_tmp"`.
function save_optim_state(filename, state)
    tmp_file = filename * "_tmp"
    isfile(tmp_file) && rm(tmp_file)
    isfile(filename) && mv(filename, tmp_file)
    serialize(filename, state)
    isfile(tmp_file) && rm(tmp_file)
end

function (ssw::SaveStateWrapper)(args...; kwargs...)
    ssw.num_calls_since_save += 1
    # Because of `>`, a save happens every `num_calls_per_save + 1` calls.
    if ssw.num_calls_since_save > ssw.num_calls_per_save
        save_optim_state(ssw.filename, ssw.optimstate)
        ssw.num_calls_since_save = 0
    end
    ssw.obj(args...; kwargs...)
end

# Pass `state` (e.g. one loaded with `deserialize`) to resume from a saved state.
function optimize_with_restart(obj, x0, method, options;
        inplace = true, autodiff = :finite,  # Optim settings
        num_calls_per_save = 10,             # wrapper settings
        state = nothing
    )
    if state === nothing
        the_state = Optim.initial_state(method, options, Optim.promote_objtype(method, x0, autodiff, inplace, obj), x0)
    else
        the_state = state
    end
    wrapped_obj = SaveStateWrapper(obj, the_state; num_calls_per_save)
    real_obj = Optim.promote_objtype(method, x0, autodiff, inplace, wrapped_obj)
    return Optim.optimize(real_obj, x0, method, options, the_state)
end
```

The following test shows that it works. You must first get the output from the first run, and then save the …

```julia
# Create a special objective function around `sum(x)` that will
# 1) throw an error when I want it to, and
# 2) record the history of objective values.
struct MyObj
    vals::Vector{Float64}
    fail_at::Int
end
MyObj(; fail_at = 10) = MyObj(Float64[], fail_at)

function (m::MyObj)(x)
    o = sum(x)
    push!(m.vals, o)
    length(m.vals) >= m.fail_at && error("Planned failure")
    return o
end

# o1 will fail after 10 function calls
o1 = MyObj()
try
    r = optimize_with_restart(o1, ones(4), NelderMead(), Optim.Options(); num_calls_per_save = 1)
catch e
    println(e)  # Simulated failure
end

# Change to the path printed when the SaveStateWrapper was created
state_file = "de16fbc0-352e-11ee-3458-2b759117f9c2.state"

# o2 will fail after 20 calls, but should restart from about where o1 left off
state = deserialize(state_file)
o2 = MyObj(; fail_at = 20)
try
    r2 = optimize_with_restart(o2, ones(4), NelderMead(), Optim.Options(); num_calls_per_save = 1, state = state)
catch e
    println(e)  # Simulated failure
end

offset = 3  # Not sure why this isn't zero or 1...
restarted_trace = append!(copy(o1.vals), o2.vals[(1 + offset):end])

# o3 runs from the beginning without interruption (up to 30 function calls)
o3 = MyObj(; fail_at = 30)
try
    r3 = optimize_with_restart(o3, ones(4), NelderMead(), Optim.Options(); num_calls_per_save = 1)
catch e
    println(e)  # Simulated failure
end

for i in 1:length(restarted_trace)
    println("$i: ", o3.vals[i], ", ", restarted_trace[i], ". Same? ", o3.vals[i] ≈ restarted_trace[i])
end
```
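`save_optim_state` moves the previous file aside before writing, so an interrupted write leaves the last good state under the `_tmp` name, and you have to know to look there. A possible alternative ordering, sketched below with only the `Serialization` stdlib, writes the new state to a side file first and then moves it over the old one with `mv(...; force = true)`, so the file under the original name is only touched once a complete new snapshot exists. `save_state_replace` and the `.partial` suffix are hypothetical names, and the state in the usage lines is a stand-in `NamedTuple` rather than a real Optim state.

```julia
using Serialization

# Hypothetical variant of `save_optim_state`: serialize to a side file first,
# then move it over the previous state. If `serialize` fails or the process is
# killed mid-write, the previous state file is left untouched.
function save_state_replace(filename::AbstractString, state)
    partial = filename * ".partial"      # hypothetical suffix for the side file
    serialize(partial, state)            # write the complete new snapshot
    mv(partial, filename; force = true)  # then replace the old snapshot
    return filename
end

# Usage with a stand-in state (a real run would pass the Optim state object).
path = joinpath(mktempdir(), "run.state")
save_state_replace(path, (iteration = 1, x = [1.0, 2.0]))
save_state_replace(path, (iteration = 2, x = [0.5, 1.5]))
restored = deserialize(path)
```

Restarting then only needs `deserialize(filename)`, with no check for a leftover temporary copy.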
Hi! I'm using this package for my research, which involves estimating a complicated model. I've been using the `NelderMead()` algorithm. From time to time, the estimation stops due to a memory issue. I have three questions:
… `store_trace=true`. Could this be the cause of the memory issue?
Thank you very much for your help!
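If the stored trace is the concern (with `store_trace=true` the trace grows with the number of iterations), one option in the spirit of the "store information in your objective function" suggestion is to leave it off and keep only a fixed number of recent values yourself. The sketch below assumes the objective value is all you want to keep and uses a hand-rolled circular buffer; `RecentValues` and `recent` are hypothetical names, and the DataStructures.jl package's `CircularBuffer` is a ready-made alternative.

```julia
# Hypothetical objective wrapper: remembers only the last `capacity` objective
# values, overwriting the oldest entry once the fixed-size buffer is full.
mutable struct RecentValues{F}
    f::F                      # the user's objective function
    values::Vector{Float64}   # fixed-size storage
    next::Int                 # slot that receives the next value (1-based)
    count::Int                # total number of calls so far
end

RecentValues(f, capacity::Int) =
    RecentValues(f, Vector{Float64}(undef, capacity), 1, 0)

function (r::RecentValues)(x)
    val = r.f(x)
    r.values[r.next] = val
    r.next = r.next % length(r.values) + 1   # advance, wrapping from capacity to 1
    r.count += 1
    return val
end

# Stored values ordered from oldest to newest.
function recent(r::RecentValues)
    r.count < length(r.values) && return r.values[1:r.count]
    return vcat(r.values[r.next:end], r.values[1:r.next-1])
end

# Usage: five calls with room for three values keeps only the last three.
r = RecentValues(x -> sum(x), 3)
for k in 1:5
    r([Float64(k)])
end
last_three = recent(r)
```

Memory use stays at `capacity` values however long the run lasts; the trade-off is that older history is discarded rather than written anywhere.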