Skip to content

Commit

Permalink
Merge branch 'qc7.5-approx-entropy' into meta-autodiff
Browse files Browse the repository at this point in the history
  • Loading branch information
rtjoa committed Jan 24, 2024
2 parents e15046c + f3bfafe commit 7f1de44
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 10 deletions.
16 changes: 16 additions & 0 deletions examples/qc/benchmarks/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,22 @@ function build_loss(::STLCConstructorEntropy, generation::STLCGeneration)
loss, nothing
end

##################################
# STLC "4231" (few apps) loss
##################################

struct STLC4321AppsLoss <: LossParams{STLC} end
to_subpath(::STLC4321AppsLoss) = ["4321apps"]
function build_loss(::STLC4321AppsLoss, generation::STLCGeneration)
metric = num_apps(generation.e)
mle_loss([
BoolToMax(prob_equals(metric, DistUInt32(0)), weight=.4),
BoolToMax(prob_equals(metric, DistUInt32(1)), weight=.3),
BoolToMax(prob_equals(metric, DistUInt32(2)), weight=.2),
BoolToMax(prob_equals(metric, DistUInt32(3)), weight=.1),
]), nothing
end

##################################
# BST generation
##################################
Expand Down
34 changes: 34 additions & 0 deletions examples/qc/benchmarks/format.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
adnodes_of_interest =

Dict("tysz2_gen_type_tbool" => 0.7207727519197863, "sz4_succ_var" => 0.4903234929252827, "tysz1_gen_type_tbool" => 0.5079797882722665, "sz0_zero_pr_var2" => 0.5, "sz2_succ_app" => 0.08352877939489839, "sz4_succ_abs" => 0.7066063985105583, "sz5_succ_var" => 0.5, "sz5_succ_abs" => 0.41721675684910503, "sz2_succ_var" => 0.5435857870063115, "sz1_succ_var" => 0.5460302801221573, "sz1_succ_app" => 0.1617094113296376, "sz1_succ_abs" => 0.6925946397620913, "sz3_succ_abs" => 0.7633559861209368, "sz3_succ_app" => 0.06947521647387661, "sz5_succ_app" => 0.5711467012997115, "sz4_succ_app" => 0.19063677116251065, "sz2_succ_abs" => 0.7423080357841665, "sz3_succ_var" => 0.5151953122355372)
@assert issetequal(keys(adnodes_of_interest), ["sz1_succ_abs", "tysz2_gen_type_tbool", "sz3_succ_abs", "sz4_succ_var", "sz3_succ_app", "sz5_succ_app", "tysz1_gen_type_tbool", "sz0_zero_pr_var2", "sz2_succ_app", "sz4_succ_abs", "sz5_succ_var", "sz4_succ_app", "sz2_succ_abs", "sz5_succ_abs", "sz3_succ_var", "sz2_succ_var", "sz1_succ_var", "sz1_succ_app"])

thousandths(n) = Integer(round(n, digits=3) * 1000)
w(s) = thousandths(adnodes_of_interest[s])

println("""genType
let '(boolWeight, funWeight) :=
get
[
(1, ($(w("tysz1_gen_type_tbool")), 1000-$(w("tysz1_gen_type_tbool"))));
(2, ($(w("tysz2_gen_type_tbool")), 2000-$(w("tysz2_gen_type_tbool"))))
]
Fixpoint genExpr env tau (sz: nat) : G (option Expr) :=
match sz with
| 0 =>
let '(var_weight, zero_weight) := ($(w("sz0_zero_pr_var2")), 1000-$(w("sz0_zero_pr_var2"))) in
backtrack
[(var_weight, oneOf_ (ret None) (map (fun x => returnGen (Some (Var x))) (genVar' env tau 0 [])))
;(zero_weight, genZero env tau)]
| S sz' =>
let '(val_weight, app_weight, var_weight) :=
get
[
(1, ($(w("sz1_succ_abs")), $(w("sz1_succ_app")), $(w("sz1_succ_var"))));
(2, ($(w("sz2_succ_abs")), $(w("sz2_succ_app")), $(w("sz2_succ_var"))));
(3, ($(w("sz3_succ_abs")), $(w("sz3_succ_app")), $(w("sz3_succ_var"))));
(4, ($(w("sz4_succ_abs")), $(w("sz4_succ_app")), $(w("sz4_succ_var"))));
(5, ($(w("sz5_succ_abs")), $(w("sz5_succ_app")), $(w("sz5_succ_var"))))
]
""")
32 changes: 22 additions & 10 deletions examples/qc/benchmarks/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,16 @@ include("benchmarks.jl")

generation_params = STLCGenerationParams(
param_vars_by_size=true,
size=3,
ty_size=1,
size=5,
ty_size=2,
)
loss_params = ApproxSTLCConstructorEntropy()
# loss_params = MixedLossParams([
# ApproxSTLCConstructorEntropy() => 10,
# MLELossParams(metric=NumApps(), target_dist=Uniform()) => 1,
# ])
loss_params = STLC4321AppsLoss()

# generation_params = BSTGenerationParams(size=3)
# loss_params = ApproxBSTConstructorEntropy()

EPOCHS = 2000
LEARNING_RATE = if loss_params isa ApproxSTLCConstructorEntropy 0.03 else 0.003 end
EPOCHS = 500
LEARNING_RATE = if loss_params isa ApproxSTLCConstructorEntropy 0.03 else 0.01 end

TAG = "v04_infra"

Expand Down Expand Up @@ -74,7 +70,11 @@ var_vals = Valuation()
adnodes_of_interest = Dict{String, ADNode}()
function register_weight!(s)
var = Var("$(s)_before_sigmoid")
var_vals[var] = 0
if !haskey(var_vals, var) || var_vals[var] == 0
var_vals[var] = 0
else
println(io, "WARNING: not registering fresh weight for $(s)")
end
weight = sigmoid(var)
adnodes_of_interest[s] = weight
weight
Expand All @@ -95,10 +95,16 @@ println(io)
# Before
############################

println(io, "Initial var_vals:")
show(io, var_vals)
println(io)
println(io)

println(io, "Initial adnodes_of_interest:")
vals = compute(var_vals, values(adnodes_of_interest))
show(io, Dict(s => vals[adnode] for (s, adnode) in adnodes_of_interest))
println(io)
println(io)

if loss_params isa MLELossParams{STLC}
println_flush(io, "Inferring initial distribution...")
Expand Down Expand Up @@ -148,10 +154,16 @@ if loss_params isa MLELossParams
end
println(io)

println(io, "Learned var_vals:")
show(io, var_vals)
println(io)
println(io)

println(io, "Learned adnodes_of_interest:")
vals = compute(var_vals, values(adnodes_of_interest))
show(io, Dict(s => vals[adnode] for (s, adnode) in adnodes_of_interest))
println(io)
println(io)

if generation isa STLCGeneration
println(io, "Inferring trained num apps distribution...")
Expand Down

0 comments on commit 7f1de44

Please sign in to comment.