Skip to content

Commit

Permalink
Model generation for ADT theory
Browse files Browse the repository at this point in the history
This PR implements the model generation for ADT. The model generation is
done by the casesplit mechanism in `Adt_rel`.

- If we turn model generation on, we performs casesplits even if the
  flag `--enable-adts-cs` isn't present in the command line.

- The termination of the model generation is a bit tricky in the case of
  mutually recursive ADT. Please see the tests added for some
  complicated examples. I hope that I caught all the corner cases.
  To ensure the termination, the basic idea is to sort ADT's
  constructors in the module `Ty` during the parsing and to use the fact
  that the SMT-LIB standard only accepts well-founded ADT.

  We choose constructors in domains with the following order:
  - Constructor with the less destructors using the same nest;
  - Constructor with the less destructors using another nest of the same
    mutually recursive declaration;
  • Loading branch information
Halbaroth committed Apr 13, 2024
1 parent 19db47a commit 38fb9a5
Show file tree
Hide file tree
Showing 19 changed files with 368 additions and 61 deletions.
1 change: 1 addition & 0 deletions src/bin/common/solving_loop.ml
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,7 @@ let main () =
st
| ":produce-models", Symbol { name = Simple "true"; _ } ->
Options.set_interpretation ILast;
Options.set_enable_adts_cs true;
st
| ":produce-models", Symbol { name = Simple "false"; _ } ->
Options.set_interpretation INone;
Expand Down
21 changes: 13 additions & 8 deletions src/lib/reasoners/adt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -405,12 +405,17 @@ module Shostak (X : ALIEN) = struct


let assign_value _ _ _ =
Printer.print_err
"[ADTs.models] assign_value currently not implemented";
raise (Util.Not_implemented "Models for ADTs")

let to_model_term _r =
Printer.print_err
"[ADTs.models] to_model_term currently not implemented";
raise (Util.Not_implemented "Models for ADTs")
(* Model generation is performed by the casesplit mechanism
in [Adt_rel]. *)
None

let to_model_term r =
match embed r with
| Constr { c_name; c_ty; c_args } ->
let args = Lists.try_map (fun (_, arg) -> X.to_model_term arg) c_args in
Option.bind args @@ fun args ->
Some (E.mk_term Sy.(Op (Constr c_name)) args c_ty)

| Select _ -> None
| Alien a -> X.to_model_term a
end
43 changes: 38 additions & 5 deletions src/lib/reasoners/adt_rel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ module Domain = struct

exception Inconsistent of Ex.t

let[@inline always] cardinal { constrs; _ } = HSS.cardinal constrs

let[@inline always] choose { constrs; _ } = HSS.choose constrs

let[@inline always] as_singleton { constrs; ex } =
if HSS.cardinal constrs = 1 then
Some (HSS.choose constrs, ex)
Expand Down Expand Up @@ -202,6 +206,8 @@ module Domains = struct
) t.changed acc
in
acc, { t with changed = SX.empty }

let iter f t = MX.iter f t.domains
end

let calc_destructor d e uf =
Expand Down Expand Up @@ -603,8 +609,12 @@ let pick_delayed_destructor env =
Rel_utils.Delayed.iter_delayed
(fun r sy _e ->
match sy with
| Sy.Destruct d ->
raise_notrace @@ Found (r, d)
| Sy.Destruct destr ->
let d = Domains.get r env.domains in
if Domain.cardinal d > 1 then
raise_notrace @@ Found (r, destr)
else
()
| _ ->
()
) env.delayed;
Expand All @@ -614,12 +624,35 @@ let pick_delayed_destructor env =
(* Do a case-split by choosing a semantic value [r] and constructor [c]
for which there are delayed destructor applications and propagate the
literal [(not (_ is c) r)]. *)
let case_split env _uf ~for_model =
if Options.get_disable_adts () || not (Options.get_enable_adts_cs())
let case_split env uf ~for_model =
if Options.get_disable_adts ()
|| not (Options.get_enable_adts_cs () || for_model)
then
[]
else if for_model then
try
Domains.iter
(fun r d ->
let rr, _ = Uf.find_r uf r in
match Th.embed rr with
| Constr _ -> ()
| _ ->
let c = Domain.choose d in
raise_notrace @@ Found (r, c)
) env.domains;
[]
with Found (r, c) ->
match build_constr_eq r c with
| Some (_, cons) ->
let nr, _ = X.make cons in
let cs = LR.mkv_eq r nr in
if Options.get_debug_adt () then
Printer.print_dbg ~flushed:false
~module_name:"Adt_rel" ~function_name:"case_split"
"Assume %a = %a" X.print r Hstring.print c;
[ cs, true, Th_util.CS (Th_util.Th_adt, two) ]
| None -> assert false
else begin
assert (not for_model);
if Options.get_debug_adt () then Debug.pp_env "before cs" env;
match pick_delayed_destructor env with
| Some (r, d) ->
Expand Down
3 changes: 2 additions & 1 deletion src/lib/reasoners/shostak.ml
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,8 @@ struct
not (Options.get_restricted ()) &&
(RECORDS.is_mine_symb sb ty ||
BITV.is_mine_symb sb ty ||
ENUM.is_mine_symb sb ty)
ENUM.is_mine_symb sb ty ||
ADT.is_mine_symb sb ty)

let is_a_leaf r = match r.v with
| Term _ | Ac _ -> true
Expand Down
10 changes: 7 additions & 3 deletions src/lib/reasoners/uf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,10 @@ type cache = {
to ensure we don't generate twice an abstract value for a given symbol. *)
}

let is_destructor = function
| Sy.Op (Destruct _) -> true
| _ -> false

(* The environment of the union-find contains almost a first-order model.
There are two situations that require some computations to retrieve an
appropriate model value:
Expand All @@ -1084,9 +1088,9 @@ let compute_concrete_model_of_val cache =
and get_abstract_for = Cache.get_abstract_for cache.abstracts
in fun env t ((mdl, mrepr) as acc) ->
let { E.f; xs; ty; _ } = E.term_view t in
if X.is_solvable_theory_symbol f ty
|| Sy.is_internal f || E.is_internal_name t || E.is_internal_skolem t
|| E.equal t E.vrai || E.equal t E.faux
if X.is_solvable_theory_symbol f ty || is_destructor f
|| Sy.is_internal f || E.is_internal_name t || E.is_internal_skolem t
|| E.equal t E.vrai || E.equal t E.faux
then
(* These terms are built-in interpreted ones and we don't have
to produce a definition for them. *)
Expand Down
140 changes: 96 additions & 44 deletions src/lib/structures/ty.ml
Original file line number Diff line number Diff line change
Expand Up @@ -463,54 +463,55 @@ module Decls = struct
MH.add name {decl = (params, body); instances = MTY.empty} !decls

let body name args =
let {decl = (params, body); instances} = MH.find name !decls in
try
let {decl = (params, body); instances} = MH.find name !decls in
try
if compare_list params args = 0 then body
else MTY.find args instances
(* should I instantiate if not found ?? *)
with Not_found ->
let params, body = fresh_type params body in
(*if true || get_debug_adt () then*)
let sbt =
try
List.fold_left2
(fun sbt vty ty ->
let vty = shorten vty in
match vty with
| Tvar { value = Some _ ; _ } -> assert false
| Tvar {v ; value = None} ->
if equal vty ty then sbt else M.add v ty sbt
| _ ->
Printer.print_err "vty = %a and ty = %a"
print vty print ty;
assert false
)M.empty params args
with Invalid_argument _ -> assert false
in
let body = match body with
| Adt cases ->
Adt(
List.map
(fun {constr; destrs} ->
{constr;
destrs =
List.map (fun (d, ty) -> d, apply_subst sbt ty) destrs }
) cases
)
in
let params = List.map (fun ty -> apply_subst sbt ty) params in
add name params body;
body
if compare_list params args = 0 then body
else MTY.find args instances
(* should I instantiate if not found ?? *)
with Not_found ->
Printer.print_err "%a not found" Hstring.print name;
assert false
let params, body = fresh_type params body in
(*if true || get_debug_adt () then*)
let sbt =
try
List.fold_left2
(fun sbt vty ty ->
let vty = shorten vty in
match vty with
| Tvar { value = Some _ ; _ } -> assert false
| Tvar {v ; value = None} ->
if equal vty ty then sbt else M.add v ty sbt
| _ ->
Printer.print_err "vty = %a and ty = %a"
print vty print ty;
assert false
)M.empty params args
with Invalid_argument _ -> assert false
in
let body = match body with
| Adt cases ->
Adt(
List.map
(fun {constr; destrs} ->
{constr;
destrs =
List.map (fun (d, ty) -> d, apply_subst sbt ty) destrs }
) cases
)
in
let params = List.map (fun ty -> apply_subst sbt ty) params in
add name params body;
body

let reinit () = decls := MH.empty

end

let type_body name args = Decls.body name args
let type_body name args =
try
Decls.body name args
with Not_found ->
Printer.print_err "%a not found" Hstring.print name;
assert false


(* smart constructors *)
Expand All @@ -524,6 +525,56 @@ let fresh_empty_text =

let tsum s lc = Tsum (Hstring.make s, List.map Hstring.make lc)

(* Count the number of occurences of the nest of [name] in the signature
of payload [l]. *)
let count_same_nest name l =
Lists.sum
(fun (_, ty) ->
match ty with
| Tadt (name', _) when Hstring.equal name name' -> 1
| _ -> 0
) l

(* Count the number of occurences of the (mutually recursive) ADT type of
[name] in the signature of payload [l]. *)
let count_same_adt name l =
Lists.sum
(fun (_, ty) ->
match ty with
| Tadt (name', params) ->
(* TODO: this is a hackish way to check that `name'` and `name` are
two nests of the same mutually recursive ADT. We should store
ADT's nests in a data structure as it's done in Dolmen. *)
begin try
let Adt cases = Decls.body name' params in
List.exists
(fun { destrs; _ } ->
List.exists
(fun (_, ty') ->
match ty' with
| Tadt (name'', _) -> Hstring.equal name name''
| _ -> false
) destrs
) cases
|> Bool.to_int
with Not_found ->
(* If we haven't already register the nest [name'], it means that
[name] and [name'] are two nests of the same ADT. *)
1
end
| _ -> 0
) l

(* Comparison function used to ensure the termination of the model
generation for recursive ADT values. *)
let cons_weight name (_, l1) (_, l2) =
let c = count_same_nest name l1 - count_same_nest name l2 in
if c <> 0 then c
else
let c = count_same_adt name l1 - count_same_adt name l2 in
if c <> 0 then c
else List.compare_lengths l1 l2

let t_adt ?(body=None) s ty_vars =
let hs = Hstring.make s in
let ty = Tadt (hs, ty_vars) in
Expand All @@ -545,12 +596,13 @@ let t_adt ?(body=None) s ty_vars =
Decls.add hs ty_vars (Adt cases)
| Some cases ->
let cases =
List.map (fun (s, l) ->
List.stable_sort (cons_weight hs) cases |>
List.map (fun (c, l) ->
let l =
List.map (fun (d, e) -> Hstring.make d, e) l
in
{constr = Hstring.make s; destrs = l}
) cases
{constr = Hstring.make c; destrs = l}
)
in
Decls.add hs ty_vars (Adt cases)
end;
Expand Down
3 changes: 3 additions & 0 deletions src/lib/util/lists.ml
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,6 @@ let rec is_sorted cmp l =
match l with
| x :: y :: xs -> cmp x y <= 0 && is_sorted cmp (y :: xs)
| [_] | [] -> true

let sum f l =
List.fold_left (fun sum i -> sum + f i) 0 l
4 changes: 4 additions & 0 deletions src/lib/util/lists.mli
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,7 @@ val try_map : ('a -> 'b option) -> 'a list -> 'b list option
val is_sorted : ('a -> 'a -> int) -> 'a list -> bool
(** [is_sorted cmp l] checks that [l] is sorted for the comparison function
[cmp]. *)

val sum : ('a -> int) -> 'a list -> int
(** [sum f l] computes the sum [f a1 + f a2 + ...] where
[l = [a1; a2; ...]]. *)
Loading

0 comments on commit 38fb9a5

Please sign in to comment.