Skip to content

Commit

Permalink
typebased stlc, tocoq, dependents
Browse files Browse the repository at this point in the history
  • Loading branch information
rtjoa committed May 1, 2024
1 parent 585a125 commit 2c411c7
Show file tree
Hide file tree
Showing 9 changed files with 373 additions and 79 deletions.
56 changes: 38 additions & 18 deletions examples/qc/benchmarks/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ function produce_loss(rs::RunState, m::SamplingEntropyLossMgr, epoch::Integer)
LogPr(prob_equals(m.val,sample))
end
lpr_eq_expanded = Dice.expand_logprs(l, lpr_eq)
diff_test_typecheck(sample, Dice.frombits(sample, Dict()))
if m.consider(sample)
num_meeting += 1
[lpr_eq_expanded * compute(a, lpr_eq_expanded), lpr_eq_expanded]
Expand Down Expand Up @@ -365,15 +366,20 @@ function generate(rs::RunState, p::BespokeSTLCGenerator)
STLCGeneration(e, constructors_overapproximation)
end

function save_coq_generator(rs, p, s, f)
path = joinpath(rs.out_dir, "$(s)_Generator.v")
open(path, "w") do file
vals = compute(rs.var_vals, values(rs.adnodes_of_interest))
adnodes_vals = Dict(s => vals[adnode] for (s, adnode) in rs.adnodes_of_interest)
println(file, f(p, adnodes_vals, rs.io))
end
println_flush(rs.io, "Saved Coq generator to $(path)")
end

function generation_params_emit_stats(rs::RunState, p::BespokeSTLCGenerator, s)
path = joinpath(rs.out_dir, "$(s)_Generator.v")
if p == BespokeSTLCGenerator(param_vars_by_size=true,size=5,ty_size=2)
path = joinpath(rs.out_dir, "$(s)_Generator.v")
open(path, "w") do file
vals = compute(rs.var_vals, values(rs.adnodes_of_interest))
adnodes_vals = Dict(s => vals[adnode] for (s, adnode) in rs.adnodes_of_interest)
println(file, bespoke_stlc_to_coq(adnodes_vals))
end
println_flush(rs.io, "Saved Coq generator to $(path)")
save_coq_generator(rs, p, s, bespoke_stlc_to_coq)
else
println_flush(rs.io, "Translation back to Coq not defined")
end
Expand All @@ -387,23 +393,30 @@ end
struct TypeBasedSTLCGenerator <: GenerationParams{STLC}
size::Integer
ty_size::Integer
dependents::Vector{Symbol}
ty_dependents::Vector{Symbol}
end
TypeBasedSTLCGenerator(; size, ty_size) = TypeBasedSTLCGenerator(size, ty_size)
TypeBasedSTLCGenerator(; size, ty_size, dependents, ty_dependents) = TypeBasedSTLCGenerator(size, ty_size, dependents, ty_dependents)
function to_subpath(p::TypeBasedSTLCGenerator)
[
"stlc",
"typebased",
"sz=$(p.size)-tysz=$(p.ty_size)",
"dependents=$(join(Base.map(string, p.dependents),"-"))",
"ty_dependents=$(join(Base.map(string, p.ty_dependents),"-"))",
]
end
function generate(rs::RunState, p::TypeBasedSTLCGenerator)
constructors_overapproximation = []
function add_ctor(v::Expr.T)
push!(constructors_overapproximation, DistSome(v))
push!(constructors_overapproximation, Opt.Some(Expr.T, v))
v
end
e = tb_gen_expr(rs, p.size, p.ty_size, add_ctor)
STLCGeneration(DistSome(e), constructors_overapproximation)
e = tb_gen_expr(rs, p, p.size, 20, add_ctor)
STLCGeneration(Opt.Some(Expr.T, e), constructors_overapproximation)
end
function generation_params_emit_stats(rs::RunState, p::TypeBasedSTLCGenerator, s)
save_coq_generator(rs, p, s, typebased_stlc_to_coq)
end

##################################
Expand Down Expand Up @@ -643,13 +656,7 @@ function generate(rs::RunState, p::TypeBasedRBTGenerator)
RBTGeneration(tb_gen_rbt(rs, p, p.size, Color.Black(), 10))
end
function generation_params_emit_stats(rs::RunState, p::TypeBasedRBTGenerator, s)
path = joinpath(rs.out_dir, "$(s)_Generator.v")
open(path, "w") do file
vals = compute(rs.var_vals, values(rs.adnodes_of_interest))
adnodes_vals = Dict(s => vals[adnode] for (s, adnode) in rs.adnodes_of_interest)
println(file, typebased_rbt_to_coq(p, adnodes_vals, rs.io))
end
println_flush(rs.io, "Saved Coq generator to $(path)")
save_coq_generator(rs, p, s, typebased_rbt_to_coq)
end

##################################
Expand Down Expand Up @@ -685,6 +692,19 @@ function create_loss_manager(rs::RunState, p::SatisfyPropertyLoss, generation)
# end
end

struct STLCWellTyped <: Property{STLC} end
function check_property(::STLCWellTyped, e::Opt.T{Expr.T})
@assert isdeterministic(e)
@match e [
Some(e) -> (@match typecheck(e) [
Some(_) -> true,
None() -> false,
]),
None() -> false,
]
end
name(::STLCWellTyped) = "stlcwelltyped"

struct BookkeepingInvariant <: Property{RBT} end
check_property(::BookkeepingInvariant, t::ColorKVTree.T) =
satisfies_bookkeeping_invariant(t)
Expand Down
1 change: 1 addition & 0 deletions examples/qc/benchmarks/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include("lruset/generator.jl")
include("stlc/dist.jl")
include("stlc/generator.jl")
include("stlc/to_coq.jl")
include("stlc/to_coq_tb.jl")
include("bst/dist.jl")
include("bst/generator.jl")
include("rbt/dist.jl")
Expand Down
4 changes: 0 additions & 4 deletions examples/qc/benchmarks/lib/rbt/to_coq.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
flatten = Iterators.flatten

function tocoq(i::Integer)
"$(i)"
end

function tocoq(c::Color.T)
@match c [
Red() -> "R",
Expand Down
131 changes: 130 additions & 1 deletion examples/qc/benchmarks/lib/stlc/dist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,80 @@ function get_error(ty)
ty
end

function opt_map(f, x::Tuple)
name, children = x
if name == :Some
e, = children
f(e)
elseif name == :None
nothing
else
error()
end
end

function opt_map(f, x::Opt.T)
@match x [
None() -> nothing,
Some(x) -> f(x),
]
end

function diff_test_typecheck(expr_dist, expr)
@assert isdeterministic(expr_dist)
opt_map(expr_dist) do expr_dist
opt_map(expr) do expr
ty1 = typecheck(expr)
ty2_dist = pr(typecheck(expr_dist))
@assert length(ty2_dist) == 1
ty2 = first(keys(ty2_dist))
if error_ty(ty1)
@assert ty2 == (:None, [])
else
@assert ty2 == (:Some, [ty1]) "$ty1 $ty2"
end
end
end
end

function to_int(x::DistUInt32)
dist = pr(x)
@assert length(dist) == 1
first(keys(dist))
end

function typecheck(ast::Expr.T, gamma, depth=0)::Opt.T{Typ.T}
@match ast [
Var(i) -> begin
var_depth = depth - to_int(i) - 1
haskey(gamma, var_depth) || return Opt.None(Typ.T)
Opt.Some(gamma[var_depth])
end,
Boolean(_) -> Opt.Some(Typ.TBool()),
Abs(t_in, e) -> begin
gamma′ = copy(gamma)
gamma′[depth] = t_in
Opt.map(Typ.T, typecheck(e, gamma′, depth + 1)) do t_out
Typ.TFun(t_in, t_out)
end
end,
App(e1, e2) -> begin
Opt.bind(Typ.T, typecheck(e1, gamma, depth)) do t1
@match t1 [
TBool() -> Opt.None(Typ.T),
TFun(t1_in, t1_out) -> Opt.bind(Typ.T, typecheck(e2, gamma, depth)) do t2
if prob_equals(t1_in, t2)
Opt.Some(t1_out)
else
Opt.None(Typ.T)
end
end,
]
end
end,
]
end

function typecheck_opt(ast)
name, children = ast
if name == :Some
Expand All @@ -159,7 +233,7 @@ end

typecheck(ast) = typecheck(ast, Dict())

function typecheck(ast, gamma, depth=0)
function typecheck(ast::Tuple, gamma, depth=0)
name, children = ast
if name == :Var
i, = children
Expand Down Expand Up @@ -196,3 +270,58 @@ function typecheck(ast, gamma, depth=0)
error("Bad node $(name)")
end
end

function eq_except_numbers(x::Typ.T, y::Typ.T)
@match x [
TBool() -> (@match y [
TBool() -> true,
TFun(_, _) -> false,
]),
TFun(a1, b1) -> (@match y [
TBool() -> false,
TFun(a2, b2) -> eq_except_numbers(a1, a2) & eq_except_numbers(b1, b2),
]),
]
end

function eq_except_numbers(x::Expr.T, y::Expr.T)
@match x [
Var(_) -> (@match y [
Var(_) -> true,
Boolean(_) -> false,
App(_, _) -> false,
Abs(_, _) -> false,
]),
Boolean(_) -> (@match y [
Var(_) -> false,
Boolean(_) -> true,
App(_, _) -> false,
Abs(_, _) -> false,
]),
App(f1, x1) -> (@match y [
Var(_) -> false,
Boolean(_) -> false,
App(f2, x2) -> eq_except_numbers(f1, f2) & eq_except_numbers(x1, x2),
Abs(_, _) -> false,
]),
Abs(ty1, e1) -> (@match y [
Var(_) -> false,
Boolean(_) -> false,
App(_, _) -> false,
Abs(ty2, e2) -> eq_except_numbers(ty1, ty2) & eq_except_numbers(e1, e2),
]),
]
end

function eq_except_numbers(x::Opt.T{T}, y::Opt.T{T}) where T
@match x [
Some(xv) -> (@match y [
Some(yv) -> eq_except_numbers(xv, yv),
None() -> false,
]),
None() -> (@match y [
Some(_) -> false,
None() -> true,
])
]
end
65 changes: 39 additions & 26 deletions examples/qc/benchmarks/lib/stlc/generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,45 +92,58 @@ function gen_expr(rs::RunState, env::Ctx, tau::Typ.T, sz::Integer, gen_typ_sz::I
)
end

function tb_gen_expr(rs::RunState, sz::Integer, ty_sz, track_return)

function tb_gen_expr(rs::RunState, p, size::Integer, last_callsite, track_return)
function get_dependent_dist(dependent)
if dependent == :size size
elseif dependent == :last_callsite last_callsite
else error() end
end
dependent_dists = [get_dependent_dist(d) for d in p.dependents]
track_return(
if sz == 0
@dice_ite if flip(register_weight!(rs, "sz$(sz)_pvar"))
DistVar(DistNat(0)) # really, this is arbitrary
if size == 0
@dice_ite if flip_for(rs, "pvar", dependent_dists)
Expr.Var(DistNat(0)) # really, this is arbitrary
else
DistBoolean(true) # really, this is arbitrary
Expr.Boolean(true) # really, this is arbitrary
end
else
sz′ = sz - 1
frequency_for(rs, "sz$(sz)_freq", [
DistVar(DistNat(0)), # really, this is arbitrary
DistBoolean(true), # really, this is arbitrary
begin
typ = tb_gen_type(rs, ty_sz) # TODO
e = tb_gen_expr(rs, sz′, ty_sz, track_return)
DistAbs(typ, e)
sz′ = size - 1
frequency_for(rs, "freq", dependent_dists, [
"var" => Expr.Var(DistNat(0)), # really, this is arbitrary
"boolean" => Expr.Boolean(true), # really, this is arbitrary
"abs" => begin
typ = tb_gen_type(rs, p, p.ty_size, 10) # TODO
e = tb_gen_expr(rs, p, sz′, 11, track_return)
Expr.Abs(typ, e)
end,
begin
e1 = tb_gen_expr(rs, sz′, ty_sz, track_return)
e2 = tb_gen_expr(rs, sz′, ty_sz, track_return)
DistApp(e1, e2)
"app" => begin
e1 = tb_gen_expr(rs, p, sz′, 12, track_return)
e2 = tb_gen_expr(rs, p, sz′, 13, track_return)
Expr.App(e1, e2)
end,
])
end
)
end

function tb_gen_type(rs::RunState, sz::Integer)
if sz == 0
DistTBool()
function tb_gen_type(rs::RunState, p, size::Integer, last_callsite)
function get_dependent_dist(dependent)
if dependent == :size size
elseif dependent == :last_callsite last_callsite
else error() end
end
dependent_dists = [get_dependent_dist(d) for d in p.ty_dependents]
if size == 0
Typ.TBool()
else
sz′ = sz - 1
@dice_ite if flip(register_weight!(rs, "tysz$(sz)_ptbool"))
DistTBool()
sz′ = size - 1
@dice_ite if flip_for(rs, "ptbool", dependent_dists)
Typ.TBool()
else
ty1 = tb_gen_type(rs, sz′)
ty2 = tb_gen_type(rs, sz′)
DistTFun(ty1, ty2)
ty1 = tb_gen_type(rs, p, sz′, 14)
ty2 = tb_gen_type(rs, p, sz′, 15)
Typ.TFun(ty1, ty2)
end
end
end
2 changes: 1 addition & 1 deletion examples/qc/benchmarks/lib/stlc/to_coq.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function bespoke_stlc_to_coq(adnodes_of_interest)
function bespoke_stlc_to_coq(_p, adnodes_of_interest, _io)
@assert issetequal(keys(adnodes_of_interest), ["sz1_succ_abs", "tysz2_gen_type_tbool", "sz3_succ_abs", "sz4_succ_var", "sz3_succ_app", "sz5_succ_app", "tysz1_gen_type_tbool", "sz0_zero_pr_var2", "sz2_succ_app", "sz4_succ_abs", "sz5_succ_var", "sz4_succ_app", "sz2_succ_abs", "sz5_succ_abs", "sz3_succ_var", "sz2_succ_var", "sz1_succ_var", "sz1_succ_app"])
w(s) = thousandths(adnodes_of_interest[s])
"""
Expand Down
Loading

0 comments on commit 2c411c7

Please sign in to comment.