Commit

generalize train, two curves

rtjoa committed Mar 9, 2024
1 parent c9447c4 commit 0ae9b57
Showing 2 changed files with 60 additions and 22 deletions.
26 changes: 17 additions & 9 deletions examples/qc/benchmarks/benchmarks.jl
@@ -63,14 +63,14 @@ end
 emit_stats(m::SimpleLossMgr, tag) = m.emit_stats(tag)
 train!(m::SimpleLossMgr; epochs, learning_rate) = m.train!(; epochs, learning_rate)
 
-function save_learning_curve(out_dir, learning_curve)
-    open(joinpath(out_dir, "learning_curve.csv"), "w") do file
+function save_learning_curve(out_dir, learning_curve, name)
+    open(joinpath(out_dir, "$(name).csv"), "w") do file
         xs = 0:length(learning_curve)-1
         for (epoch, logpr) in zip(xs, learning_curve)
             println(file, "$(epoch)\t$(logpr)")
         end
         plot(xs, learning_curve)
-        savefig(joinpath(out_dir, "learning_curve.svg"))
+        savefig(joinpath(out_dir, "$(name).svg"))
     end
 end

@@ -86,7 +86,7 @@ function create_simple_loss_manager(loss, io, out_dir, var_vals)
         println(io, " $(time_train) seconds")
         println(io)
 
-        save_learning_curve(out_dir, learning_curve)
+        save_learning_curve(out_dir, learning_curve, "loss")
     end
     SimpleLossMgr(emit_stats, f_train, loss)
 end
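
The name argument now determines both the CSV and the SVG filename, so the same helper can record several curves side by side. A minimal sketch of the resulting call pattern (the output directory and curve data below are hypothetical, not from this commit):

    # Hypothetical calls; each writes <out_dir>/<name>.csv and <out_dir>/<name>.svg
    save_learning_curve("out", [2.0, 1.5, 1.2], "sampling_loss")
    save_learning_curve("out", [0.9, 0.7, 0.6], "additional_loss")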
@@ -95,6 +95,8 @@ function train_via_sampling_entropy!(io, out_dir, var_vals, e; epochs, learning_
     learning_rate = learning_rate / samples_per_batch
 
     learning_curve = []
+    additional_learning_curve = []
+
     time_sample = 0
     time_step = 0
     println_flush(io, "Training...")
@@ -119,13 +121,18 @@ function train_via_sampling_entropy!(io, out_dir, var_vals, e; epochs, learning_
         last_batch = epochs_done + epochs_this_batch == epochs
 
         println_flush(io, "Stepping...")
-        time_step_here = @elapsed subcurve = Dice.train!(var_vals,
-            loss + (additional_loss * additional_loss_lr / learning_rate)
-            ; epochs=epochs_this_batch, learning_rate, append_last_loss=last_batch)
+        time_step_here = @elapsed subcurve, additional_subcurve = Dice.train!(
+            var_vals,
+            [loss => learning_rate, additional_loss => additional_loss_lr];
+            epochs=epochs_this_batch, append_last_loss=last_batch)
         time_step += time_step_here
         append!(learning_curve, subcurve)
+        append!(additional_learning_curve, additional_subcurve)
         println(io, " $(time_step_here) seconds")
-        if isinf(last(learning_curve)) || isnan(last(learning_curve))
+        if (isinf(last(learning_curve)) || isnan(last(learning_curve))
+            || isinf(last(additional_learning_curve))
+            || isnan(last(additional_learning_curve))
+        )
             println(io, "Stopping early due to Inf/NaN loss")
             break
         end
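
Since the early-stopping guard now has to inspect the tail of both curves, an equivalent and slightly more compact formulation (an editorial alternative, not what the commit uses) would be:

    # Sketch: the same Inf/NaN check expressed once over both curves
    diverged(curve) = isinf(last(curve)) || isnan(last(curve))
    if any(diverged, (learning_curve, additional_learning_curve))
        println(io, "Stopping early due to Inf/NaN loss")
        break
    end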
@@ -135,7 +142,8 @@ function train_via_sampling_entropy!(io, out_dir, var_vals, e; epochs, learning_
     println(io, "Step time: $(time_step) seconds")
     println(io)
 
-    save_learning_curve(out_dir, learning_curve)
+    save_learning_curve(out_dir, learning_curve, "sampling_loss")
+    save_learning_curve(out_dir, additional_learning_curve, "additional_loss")
 end
 
 ##################################
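
Why the old and new Dice.train! call sites apply the same parameter update: previously the additional loss was folded into a single objective and the whole gradient was scaled by learning_rate inside train!; now each loss is seeded with its own rate (via the Derivs seeding in src/autodiff_pr/train.jl below) and the step is applied unscaled. Writing ∇ for the gradient with respect to the trained variables:

    old step = learning_rate * ∇(loss + additional_loss * additional_loss_lr / learning_rate)
             = learning_rate * ∇loss + additional_loss_lr * ∇additional_loss
    new step = ∇(learning_rate * loss + additional_loss_lr * additional_loss)
             = learning_rate * ∇loss + additional_loss_lr * ∇additional_loss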
56 changes: 43 additions & 13 deletions src/autodiff_pr/train.jl
@@ -77,34 +77,64 @@ end
 
 function train!(
     var_vals::Valuation,
-    loss::ADNode;
+    loss_lrs::Vector{<:Pair{<:ADNode, <:Real}};
     epochs::Integer,
-    learning_rate::Real,
     append_last_loss=true,
     stop_if_inf_or_nan=true,
 )
-    losses = []
-    l = LogPrExpander(WMC(BDDCompiler(bool_roots([loss]))))
-    loss = expand_logprs(l, loss)
+    # Unzip
+    losses = ADNode[]
+    lrs = Real[]
+    for (loss, lr) in loss_lrs
+        push!(losses, loss)
+        push!(lrs, lr)
+    end
+
+    # Expand
+    l = LogPrExpander(WMC(BDDCompiler(bool_roots(losses))))
+    losses = [expand_logprs(l, loss) for loss in losses]
+
+    curves = [[] for _ in 1:length(losses)]
+    function update_curves(vals)
+        for (i, loss) in enumerate(losses)
+            push!(curves[i], vals[loss])
+        end
+    end
+
     for _ in 1:epochs
-        vals, derivs = differentiate(var_vals, Derivs(loss => 1))
+        vals, derivs = differentiate(
+            var_vals,
+            Derivs(loss => lr for (loss, lr) in zip(losses, lrs))
+        )
 
-        if stop_if_inf_or_nan && (isinf(vals[loss]) || isnan(vals[loss]))
-            push!(losses, vals[loss])
-            return losses
+        if stop_if_inf_or_nan && any(isinf(vals[loss]) || isnan(vals[loss]) for loss in losses)
+            update_curves(vals)
+            return curves
         end
 
         # update vars
         for (adnode, d) in derivs
             if adnode isa Var
-                var_vals[adnode] -= d * learning_rate
+                var_vals[adnode] -= d
             end
         end
 
-        push!(losses, vals[loss])
+        update_curves(vals)
     end
-    append_last_loss && push!(losses, compute_mixed(var_vals, loss))
-    losses
+
+    append_last_loss && update_curves(compute(var_vals, losses))
+    curves
 end
+
+function train!(
+    var_vals::Valuation,
+    loss::ADNode;
+    epochs::Integer,
+    learning_rate::Real,
+    append_last_loss=true,
+    stop_if_inf_or_nan=true,
+)
+    train!(var_vals, [loss => learning_rate]; epochs, append_last_loss, stop_if_inf_or_nan)[1]
+end
 
 function collect_flips(bools)
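
A sketch of how the reworked train! might be called, assuming var_vals and two ADNode losses already exist (the variable names and rates are illustrative; only the signatures come from this commit):

    # Multi-loss form: per-loss learning rates, one learning curve per loss.
    curves = train!(var_vals, [loss => 0.1, additional_loss => 0.003]; epochs=100)
    loss_curve, additional_curve = curves

    # Single-loss form: the backward-compatible method returns just that loss's curve.
    loss_curve = train!(var_vals, loss; epochs=100, learning_rate=0.1)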
