diff --git a/src/lib/reasoners/adt_rel.ml b/src/lib/reasoners/adt_rel.ml index 18eb871b3..0009c3298 100644 --- a/src/lib/reasoners/adt_rel.ml +++ b/src/lib/reasoners/adt_rel.ml @@ -191,7 +191,7 @@ module Domains = struct with Not_found -> Domain.unknown (X.type_info r) - let add r t = + let init r t = match Th.embed r with | Alien _ when not (MX.mem r t.domains) -> (* We have to add a default domain if the key `r` is not in map in order @@ -236,7 +236,7 @@ module Domains = struct let t = remove r t in tighten nr nd t - | exception Not_found -> add nr t + | exception Not_found -> init nr t (* [propagate f a t] iterates on all the changed domains of [t] since the last call of [propagate]. The list of changed domains is flushed after @@ -431,7 +431,7 @@ let add r uf domains = | Ty.Tadt _ -> Debug.add r; let rr, _ = Uf.find_r uf r in - Domains.add rr domains + Domains.init rr domains | _ -> domains diff --git a/src/lib/reasoners/bitv_rel.ml b/src/lib/reasoners/bitv_rel.ml index 2f39c5b3a..a84cd6285 100644 --- a/src/lib/reasoners/bitv_rel.ml +++ b/src/lib/reasoners/bitv_rel.ml @@ -89,6 +89,12 @@ module Interval_domain = struct let add_explanation = Intervals.Int.add_explanation + type constant = Z.t + + let constant n = Intervals.Int.of_bounds (Closed n) (Closed n) + + let filter_ty = is_bv_ty + let unknown = function | Ty.Tbitv n -> Intervals.Int.of_bounds @@ -97,8 +103,6 @@ module Interval_domain = struct Fmt.invalid_arg "unknown: only bit-vector types are supported; got %a" Ty.print ty - let filter_ty = is_bv_ty - let intersect x y = match Intervals.Int.intersect x y with | Empty ex -> @@ -108,128 +112,131 @@ module Interval_domain = struct let lognot sz int = Intervals.Int.extract ~ofs:0 ~len:sz @@ Intervals.Int.lognot int - let fold_signed f { Bitv.value; negated } sz int acc = - f value (if negated then lognot sz int else int) acc - - let point ?ex n = - Intervals.Int.of_bounds ?ex (Closed n) (Closed n) - - let fold_leaves f r int acc = - let width = bitwidth r in - let j, acc = - List.fold_left (fun (j, acc) { Bitv.bv; sz } -> - (* sz = j - i + 1 => i = j - sz + 1 *) - let int = Intervals.Int.extract int ~ofs:(j - sz + 1) ~len:sz in - - let acc = match bv with - | Bitv.Cte z -> - (* Nothing to update, but still check for consistency *) - ignore @@ intersect int (point z); - acc - | Other s -> fold_signed f s sz int acc - | Ext (r, r_size, i, j) -> - (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) - assert (i + r_size - j - 1 + sz = r_size); - let lo = unknown (Tbitv i) in - let int = Intervals.Int.scale Z.(~$1 lsl i) int in - let hi = unknown (Tbitv (r_size - j - 1)) in - let hi = - Intervals.Int.scale Z.(~$1 lsl Stdlib.(i + sz)) hi - in - fold_signed f r r_size Intervals.Int.(add hi (add int lo)) acc - in - (j - sz), acc - ) (width - 1, acc) (Shostak.Bitv.embed r) - in - assert (j = -1); - acc + let add_offset d cte = + Intervals.Int.add d (Intervals.Int.of_bounds (Closed cte) (Closed cte)) - let map_signed f { Bitv.value; negated } sz = - if negated then lognot sz (f value) else f value - - let map_leaves f r = - List.fold_left (fun ival { Bitv.bv; sz } -> - let ival = Intervals.Int.scale Z.(~$1 lsl sz) ival in - Intervals.Int.add ival @@ - match bv with - | Bitv.Cte z -> point z - | Other s -> map_signed f s sz - | Ext (s, sz', i, j) -> - Intervals.Int.extract (map_signed f s sz') ~ofs:i ~len:(j - i + 1) - ) (point Z.zero) (Shostak.Bitv.embed r) + let sub_offset d cte = + Intervals.Int.sub d (Intervals.Int.of_bounds (Closed cte) (Closed cte)) end -module Interval_domains = Rel_utils.Domains_make(Interval_domain) +type 'a explained = { value : 'a ; explanation : Explanation.t } -module Bitlist_domain : Rel_utils.Domain with type t = Bitlist.t = struct - (* Note: these functions are not in [Bitlist] proper in order to avoid a - (direct) dependency from [Bitlist] to the [Shostak] module. *) +let explained ~ex value = { value ; explanation = ex } - include Bitlist +module ExplainedOrdered(V : Rel_utils.OrderedType) : + Rel_utils.OrderedType with type t = V.t explained = +struct + type t = V.t explained - let filter_ty = is_bv_ty - - let fold_signed sz f { Bitv.value; negated } bl acc = - let bl = if negated then extract (lognot bl) 0 sz else bl in - f value bl acc + let pp ppf { value; _ } = V.pp ppf value - let fold_leaves f r bl acc = - let sz = bitwidth r in - let (acc, _, _) = List.fold_left (fun (acc, bl, w) { Bitv.bv; sz } -> - (* Extract the bitlist associated with the current component *) - let mid = w - sz in - let bl_tail = extract bl 0 mid in - let bl = extract bl mid (w - mid) in + let compare { value = v1; _ } { value = v2; _ } = V.compare v1 v2 - match bv with - | Bitv.Cte z -> - assert (Z.numbits z <= sz); - (* Nothing to update, but still check for consistency! *) - ignore @@ intersect bl (exact z Ex.empty); - acc, bl_tail, mid - | Other r -> fold_signed sz f r bl acc, bl_tail, mid - | Ext (r, r_size, i, j) -> - (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) - assert (i + r_size - j - 1 + w - mid = r_size); - let hi = Bitlist.(extract unknown 0 (r_size - j - 1)) in - let lo = Bitlist.(extract unknown 0 i) in - let bl_hd = Bitlist.((hi lsl (j + 1)) lor (bl lsl i) lor lo) in - fold_signed r_size f r bl_hd acc, - bl_tail, - mid - ) (acc, bl, sz) (Shostak.Bitv.embed r) - in acc + module Set = Set.Make(struct + type nonrec t = t - let map_signed sz f { Bitv.value; negated } = - let bl = f value in - if negated then extract (lognot bl) 0 sz else bl + let compare = compare + end) - let map_leaves f r = - List.fold_left (fun bl { Bitv.bv; sz } -> - bl lsl sz lor - match bv with - | Bitv.Cte z -> extract (exact z Ex.empty) 0 sz - | Other r -> map_signed sz f r - | Ext (r, r_sz, i, j) -> - extract (map_signed r_sz f r) i (j - i + 1) - ) (exact Z.zero Ex.empty) (Shostak.Bitv.embed r) + module Map = Map.Make(struct + type nonrec t = t - let unknown = function - | Ty.Tbitv n -> extract unknown 0 n - | _ -> - (* Only bit-vector values can have bitlist domains. *) - invalid_arg "unknown" + let compare = compare + end) end -module Bitlist_domains = Rel_utils.Domains_make(Bitlist_domain) +module BitvNormalForm = struct + (** Normal form for bit-vector values. + + We decompose non-constant bit-vector compositions as a variable part, + where all constant bits are set to [0] and all high constant bits are + chopped off, and an offset with all the constant bits. We consider the + variable part atomic if it is a single non-negated variable. + + Assuming [x] and [y] are bit-vectors of width 2: + - [101 @ x] is [x + 10100] ; + - [10 @ x @ 01] is [(x @ 00) + 100001] ; + - [10 @ y<0, 0> @ y<1, 1>] is [(y<0, 0> @ y<1>1) + 1000] ; + - [10 @ x @ 11 @ y] is [(x @ 00 @ y) + 10001100] *) + + type constant = Z.t + + type atom = X.r + + type composite = X.r + + type t = + | Constant of constant + | Atom of atom * constant + | Composite of composite * constant + + type expr = X.r + + let normal_form r = + let rec loop cte rev_acc = function + | [] -> ( + match rev_acc with + | [] -> + Constant cte + | [ { Bitv.bv = Bitv.Other { value ; negated = false }; _ } ] -> + Atom (value, cte) + | _ -> + Composite (Shostak.Bitv.is_mine (List.rev rev_acc), cte) + ) + | { Bitv.bv = Bitv.Cte n ; sz } :: bv' -> + let cte = Z.(cte lsl sz lor n) in + let acc = + match rev_acc with + | [] -> [] + | _ -> { Bitv.bv = Bitv.Cte Z.zero ; sz } :: rev_acc + in + loop cte acc bv' + | x :: bv' -> + let cte = Z.(cte lsl x.sz) in + loop cte (x :: rev_acc) bv' + in loop Z.zero [] (Shostak.Bitv.embed r) +end module Constraint : sig - include Rel_utils.Constraint + type binop = + (* Bitwise operations *) + | Band | Bor | Bxor + (* Arithmetic operations *) + | Badd | Bmul | Budiv | Burem + (* Shift operations *) + | Bshl | Blshr + + type fun_t = + | Fbinop of binop * X.r * X.r + + type binrel = Rule | Rugt + + type rel_t = + | Rbinrel of binrel * X.r * X.r + + type view = + | Cfun of X.r * fun_t + | Crel of rel_t + + type t + + val view : t -> view + + val pp : t Fmt.t + (** Pretty-printer for constraints. *) val equal : t -> t -> bool val hash : t -> int + val compare : t -> t -> int + (** Comparison function for constraints. The comparison function is + arbitrary and has no semantic meaning. You should not depend on any of + its properties, other than it defines an (arbitrary) total order on + constraint representations. *) + + val fold_args : (X.r -> 'a -> 'a) -> t -> 'a -> 'a + val bvand : X.r -> X.r -> X.r -> t (** [bvand x y z] is the constraint [x = y & z] *) @@ -268,14 +275,6 @@ module Constraint : sig val bvule : X.r -> X.r -> t val bvugt : X.r -> X.r -> t - - val propagate_bitlist : ex:Ex.t -> t -> Bitlist_domains.Ephemeral.t -> unit - (** [propagate ~ex t dom] propagates the constraint [t] in domain [dom]. - - The explanation [ex] justifies that the constraint [t] applies, and must - be added to any domain that gets updated during propagation. *) - - val propagate_interval : ex:Ex.t -> t -> Interval_domains.Ephemeral.t -> unit end = struct type binop = (* Bitwise operations *) @@ -330,9 +329,327 @@ end = struct | Band | Bor | Bxor | Badd | Bmul -> true | Budiv | Burem | Bshl | Blshr -> false + type fun_t = + | Fbinop of binop * X.r * X.r + + let pp_fun_t ppf = function + | Fbinop (op, x, y) -> + Fmt.pf ppf "%a@[(%a,@ %a)@]" pp_binop op X.print x X.print y + + let equal_fun_t f1 f2 = + match f1, f2 with + | Fbinop (op1, x1, y1), Fbinop (op2, x2, y2) -> + equal_binop op1 op2 && X.equal x1 x2 && X.equal y1 y2 + + let hash_fun_t = function + | Fbinop (op, x, y) -> Hashtbl.hash (hash_binop op, X.hash x, X.hash y) + + let normalize_fun_t = function + | Fbinop (op, x, y) when is_commutative op && X.hash_cmp x y > 0 -> + Fbinop (op, y, x) + | Fbinop _ as e -> e + + type binrel = Rule | Rugt + + let pp_binrel ppf = function + | Rule -> Fmt.pf ppf "bvule" + | Rugt -> Fmt.pf ppf "bvugt" + + let equal_binrel : binrel -> binrel -> bool = Stdlib.(=) + + let hash_binrel : binrel -> int = Hashtbl.hash + + type rel_t = + | Rbinrel of binrel * X.r * X.r + + let pp_rel_t ppf = function + | Rbinrel (op, x, y) -> + Fmt.pf ppf "%a@[(%a,@ %a)@]" pp_binrel op X.print x X.print y + + let equal_rel_t f1 f2 = + match f1, f2 with + | Rbinrel (op1, x1, y1), Rbinrel (op2, x2, y2) -> + equal_binrel op1 op2 && X.equal x1 x2 && X.equal y1 y2 + + let hash_rel_t = function + | Rbinrel (op, x, y) -> Hashtbl.hash (hash_binrel op, X.hash x, X.hash y) + + let normalize_rel_t = function + (* No normalization for relations yet *) + | r -> r + + type view = + | Cfun of X.r * fun_t + | Crel of rel_t + + let pp_view ppf = function + | Cfun (r, fn) -> + Fmt.(pf ppf "%a =@ %a" (box X.print) r (box pp_fun_t) fn) + | Crel rel -> + pp_rel_t ppf rel + + let equal_view c1 c2 = + match c1, c2 with + | Cfun (r1, f1), Cfun (r2, f2) -> + X.equal r1 r2 && equal_fun_t f1 f2 + | Cfun _, _ | _, Cfun _ -> false + + | Crel r1, Crel r2 -> + equal_rel_t r1 r2 + + let hash_view = function + | Cfun (r, f) -> Hashtbl.hash (0, X.hash r, hash_fun_t f) + | Crel r -> Hashtbl.hash (1, hash_rel_t r) + + let normalize_view = function + | Cfun (r, f) -> Cfun (r, normalize_fun_t f) + | Crel r -> Crel (normalize_rel_t r) + + type t = { view : view ; mutable tag : int } + + let view { view ; _ } = view + + let pp ppf { view; _ } = pp_view ppf view + + module W = Weak.Make(struct + type nonrec t = t + + let equal c1 c2 = equal_view c1.view c2.view + + let hash c = hash_view c.view + end) + + let hcons = + let cnt = ref 0 in + let tbl = W.create 17 in + fun view -> + let view = normalize_view view in + let tagged = W.merge tbl { view ; tag = -1 } in + if tagged.tag = -1 then ( + tagged.tag <- !cnt; + incr cnt + ); + tagged + + let cfun r f = hcons @@ Cfun (r, f) + + let cbinop op r x y = cfun r (Fbinop (op, x, y)) + + let bvand = cbinop Band + let bvor = cbinop Bor + let bvxor = cbinop Bxor + let bvadd = cbinop Badd + let bvsub r x y = + (* r = x - y <-> x = r + y *) + bvadd x r y + let bvmul = cbinop Bmul + let bvudiv = cbinop Budiv + let bvurem = cbinop Burem + let bvshl = cbinop Bshl + let bvlshr = cbinop Blshr + + let crel r = hcons @@ Crel r + + let cbinrel op x y = crel (Rbinrel (op, x, y)) + + let bvule = cbinrel Rule + let bvugt = cbinrel Rugt + + let equal c1 c2 = c1.tag = c2.tag + + let hash c = Hashtbl.hash c.tag + + let compare c1 c2 = Int.compare c1.tag c2.tag + + let fold_args_fun_t f fn acc = + match fn with + | Fbinop (_, x, y) -> f y (f x acc) + + let fold_args_rel_t f r acc = + match r with + | Rbinrel (_op, x, y) -> f y (f x acc) + + let fold_args_view f c acc = + match c with + | Cfun (r, fn) -> fold_args_fun_t f fn (f r acc) + | Crel r -> fold_args_rel_t f r acc + + let fold_args f c acc = fold_args_view f (view c) acc +end + + +module EC = ExplainedOrdered(struct + include Constraint + + module Set = Set.Make(Constraint) + module Map = Map.Make(Constraint) + end) + +module CompositeIntervalDomain = struct + type var = X.r + + type atom = X.r + + type domain = Interval_domain.t + + let map_signed f { Bitv.value; negated } sz = + if negated then Interval_domain.lognot sz (f value) else f value + + let map_domain f r = + List.fold_left (fun ival { Bitv.bv; sz } -> + let ival = Intervals.Int.scale Z.(~$1 lsl sz) ival in + Intervals.Int.add ival @@ + match bv with + | Bitv.Cte z -> Interval_domain.constant z + | Other s -> map_signed f s sz + | Ext (s, sz', i, j) -> + Intervals.Int.extract (map_signed f s sz') ~ofs:i ~len:(j - i + 1) + ) (Interval_domain.constant Z.zero) (Shostak.Bitv.embed r) +end + +module XComposite = struct + include Rel_utils.XComparable + + type atom = X.r + + let fold f r acc = + List.fold_left (fun acc { Bitv.bv ; _ } -> + match bv with + | Bitv.Cte _ -> acc + | Other { value ; _ } -> f value acc + | Ext ({ value ; _ }, _, _, _) -> f value acc + ) acc (Shostak.Bitv.embed r) +end + +module XAtom = struct + include Rel_utils.XComparable + + let type_info = X.type_info +end + +module Interval_domains = + Rel_utils.Domains_make + (Interval_domain) + (XAtom) + (XComposite) + (CompositeIntervalDomain) + (BitvNormalForm) + (EC) + +module Interval_domains_uf = + Rel_utils.UfHandle + (Interval_domain) + (Interval_domains.Ephemeral) + +module Bitlist_domain = struct + (* Note: these functions are not in [Bitlist] proper in order to avoid a + (direct) dependency from [Bitlist] to the [Shostak] module. *) + + include Bitlist + + type constant = Z.t + + let constant n = exact n Ex.empty + + let filter_ty = is_bv_ty + + let unknown = function + | Ty.Tbitv n -> extract unknown 0 n + | _ -> + (* Only bit-vector values can have bitlist domains. *) + invalid_arg "unknown" + + let add_offset d cte = + Bitlist.logor d (Bitlist.exact cte Explanation.empty) + + let sub_offset d cte = + let cte = Bitlist.exact cte Explanation.empty in + Bitlist.logand d (Bitlist.lognot cte) +end + +module CompositeBitlistDomain = struct + type var = X.r + + type atom = X.r + + type domain = Bitlist_domain.t + + let map_signed sz f { Bitv.value; negated } = + let bl = f value in + if negated then Bitlist.extract (Bitlist.lognot bl) 0 sz else bl + + let map_domain f r = + List.fold_left (fun bl { Bitv.bv; sz } -> + let open Bitlist in + bl lsl sz lor + match bv with + | Bitv.Cte z -> extract (Bitlist_domain.constant z) 0 sz + | Other r -> map_signed sz f r + | Ext (r, r_sz, i, j) -> + extract (map_signed r_sz f r) i (j - i + 1) + ) (Bitlist_domain.constant Z.zero) (Shostak.Bitv.embed r) +end + +module Bitlist_domains = + Rel_utils.Domains_make + (Bitlist_domain) + (XAtom) + (XComposite) + (CompositeBitlistDomain) + (BitvNormalForm) + (EC) + +module Bitlist_domains_uf = + Rel_utils.UfHandle + (Bitlist_domain) + (Bitlist_domains.Ephemeral) + +(** The ['c acts] type is used to register new facts and constraints in + [Propagator.simplify]. *) +type 'c acts = + { acts_add_lit_view : ex:Explanation.t -> X.r L.view -> unit + (** Assert a semantic literal. *) + ; acts_add_eq : ex:Explanation.t -> X.r -> X.r -> unit + (** Assert equality between two semantic values. *) + ; acts_add_constraint : ex:Explanation.t -> 'c -> unit + (** Assert a new constraint. *) + } + +module Propagator : sig + type t = Constraint.t + (** The type of constraints. + + Constraints apply to semantic values of type [X.r] as arguments. *) + + val simplify : Uf.t -> t -> t acts -> bool + (** [simplify c acts] simplifies the constraint [c] by calling appropriate + functions on [acts]. + + {b Note}: All the facts and constraints added through [acts] must be + logically implied by [c] {b only}. Doing otherwise is a {b soundness bug}. + + Returns [true] if the constraint has been fully simplified and can + be removed, and [false] otherwise. + + {b Note}: Returning [true] will cause the constraint to be removed, even + if it was re-added with [acts_add_constraint]. If you want to add new + facts/constraints but keep the existing constraint (usually a bad idea), + return [false] instead. *) + + val propagate_bitlist : Bitlist_domains_uf.t -> ex:Ex.t -> t -> unit + (** [propagate dom ~ex t] propagates the constraint [t] in domain [dom]. + + The explanation [ex] justifies that the constraint [t] applies, and must + be added to any domain that gets updated during propagation. *) + + val propagate_interval : + Interval_domains_uf.t -> ex:Ex.t -> t -> unit +end = struct + include Constraint + let propagate_binop ~ex sz dx op dy dz = - let open Bitlist_domains.Ephemeral in let norm bl = Bitlist.extract bl 0 sz in + let open Rel_utils.HandleNotations(Bitlist_domain)(Bitlist_domains_uf) in match op with | Band -> update ~ex dx @@ norm @@ Bitlist.logand !!dy !!dz; @@ -383,7 +700,7 @@ end = struct () let propagate_interval_binop ~ex sz dr op dx dy = - let open Interval_domains.Ephemeral in + let open Rel_utils.HandleNotations(Interval_domain)(Interval_domains_uf) in let norm i = Intervals.Int.extract i ~ofs:0 ~len:sz in match op with | Badd -> @@ -410,58 +727,20 @@ end = struct (* No interval propagation for bitwise operators yet *) () - type fun_t = - | Fbinop of binop * X.r * X.r - - let pp_fun_t ppf = function - | Fbinop (op, x, y) -> - Fmt.pf ppf "%a@[(%a,@ %a)@]" pp_binop op X.print x X.print y - - let equal_fun_t f1 f2 = - match f1, f2 with - | Fbinop (op1, x1, y1), Fbinop (op2, x2, y2) -> - equal_binop op1 op2 && X.equal x1 x2 && X.equal y1 y2 - - let hash_fun_t = function - | Fbinop (op, x, y) -> Hashtbl.hash (hash_binop op, X.hash x, X.hash y) - - let normalize_fun_t = function - | Fbinop (op, x, y) when is_commutative op && X.hash_cmp x y > 0 -> - Fbinop (op, y, x) - | Fbinop _ as e -> e - - let fold_args_fun_t f fn acc = - match fn with - | Fbinop (_, x, y) -> f y (f x acc) - - let subst_fun_t rr nrr = function - | Fbinop (op, x, y) -> Fbinop (op, X.subst rr nrr x, X.subst rr nrr y) - let propagate_fun_t ~ex dom r f = - let open Bitlist_domains.Ephemeral in - let get r = handle dom r in + let get r = Bitlist_domains_uf.entry dom r in match f with | Fbinop (op, x, y) -> let n = bitwidth r in propagate_binop ~ex n (get r) op (get x) (get y) let propagate_interval_fun_t ~ex dom r f = - let get r = Interval_domains.Ephemeral.handle dom r in + let get r = Interval_domains_uf.entry dom r in match f with | Fbinop (op, x, y) -> let sz = bitwidth r in propagate_interval_binop ~ex sz (get r) op (get x) (get y) - type binrel = Rule | Rugt - - let pp_binrel ppf = function - | Rule -> Fmt.pf ppf "bvule" - | Rugt -> Fmt.pf ppf "bvugt" - - let equal_binrel : binrel -> binrel -> bool = Stdlib.(=) - - let hash_binrel : binrel -> int = Hashtbl.hash - let propagate_binrel ~ex:_ _op _dx _dy = (* No bitlist propagation for relations yet *) () @@ -477,7 +756,7 @@ end = struct Intervals.Int.of_bounds ~ex:(Ex.union ex ex') inf Unbounded let propagate_less_than ~ex ~strict dx dy = - let open Interval_domains.Ephemeral in + let open Rel_utils.HandleNotations(Interval_domain)(Interval_domains_uf) in (* Add justification prior to calling [update] to ensure that it is only stored on the appropriate bound. *) update ~ex:Ex.empty dx (less_than_sup ~ex ~strict !!dy); @@ -490,153 +769,31 @@ end = struct | Rugt -> propagate_less_than ~ex ~strict:true dy dx - type rel_t = - | Rbinrel of binrel * X.r * X.r - - let pp_rel_t ppf = function - | Rbinrel (op, x, y) -> - Fmt.pf ppf "%a@[(%a,@ %a)@]" pp_binrel op X.print x X.print y - - let equal_rel_t f1 f2 = - match f1, f2 with - | Rbinrel (op1, x1, y1), Rbinrel (op2, x2, y2) -> - equal_binrel op1 op2 && X.equal x1 x2 && X.equal y1 y2 - - let hash_rel_t = function - | Rbinrel (op, x, y) -> Hashtbl.hash (hash_binrel op, X.hash x, X.hash y) - - let normalize_rel_t = function - (* No normalization for relations yet *) - | r -> r - - let fold_args_rel_t f r acc = - match r with - | Rbinrel (_op, x, y) -> f y (f x acc) - - let subst_rel_t rr nrr = function - | Rbinrel (op, x, y) -> Rbinrel (op, X.subst rr nrr x, X.subst rr nrr y) - let propagate_rel_t ~ex dom r = - let open Bitlist_domains.Ephemeral in - let get r = handle dom r in + let get r = Bitlist_domains_uf.entry dom r in match r with | Rbinrel (op, x, y) -> propagate_binrel ~ex op (get x) (get y) let propagate_interval_rel_t ~ex dom r = - let get r = Interval_domains.Ephemeral.handle dom r in + let get r = Interval_domains_uf.entry dom r in match r with | Rbinrel (op, x, y) -> propagate_interval_binrel ~ex op (get x) (get y) - type repr = - | Cfun of X.r * fun_t - | Crel of rel_t - - let pp_repr ppf = function - | Cfun (r, fn) -> - Fmt.(pf ppf "%a =@ %a" (box X.print) r (box pp_fun_t) fn) - | Crel rel -> - pp_rel_t ppf rel - - let equal_repr c1 c2 = - match c1, c2 with - | Cfun (r1, f1), Cfun (r2, f2) -> - X.equal r1 r2 && equal_fun_t f1 f2 - | Cfun _, _ | _, Cfun _ -> false - - | Crel r1, Crel r2 -> - equal_rel_t r1 r2 - - let hash_repr = function - | Cfun (r, f) -> Hashtbl.hash (0, X.hash r, hash_fun_t f) - | Crel r -> Hashtbl.hash (1, hash_rel_t r) - - let normalize_repr = function - | Cfun (r, f) -> Cfun (r, normalize_fun_t f) - | Crel r -> Crel (normalize_rel_t r) - - let fold_args_repr f c acc = - match c with - | Cfun (r, fn) -> fold_args_fun_t f fn (f r acc) - | Crel r -> fold_args_rel_t f r acc - - let subst_repr rr nrr = function - | Cfun (r, f) -> Cfun (X.subst rr nrr r, subst_fun_t rr nrr f) - | Crel r -> Crel (subst_rel_t rr nrr r) - - let propagate_repr ~ex dom = function + let propagate_view ~ex dom = function | Cfun (r, f) -> propagate_fun_t ~ex dom r f | Crel r -> propagate_rel_t ~ex dom r - let propagate_interval_repr ~ex dom = function + let propagate_interval_view ~ex dom = function | Cfun (r, f) -> propagate_interval_fun_t ~ex dom r f | Crel r -> propagate_interval_rel_t ~ex dom r - type t = { repr : repr ; mutable tag : int } - - let pp ppf { repr; _ } = pp_repr ppf repr - - module W = Weak.Make(struct - type nonrec t = t - - let equal c1 c2 = equal_repr c1.repr c2.repr - - let hash c = hash_repr c.repr - end) - - let hcons = - let cnt = ref 0 in - let tbl = W.create 17 in - fun repr -> - let repr = normalize_repr repr in - let tagged = W.merge tbl { repr ; tag = -1 } in - if tagged.tag = -1 then ( - tagged.tag <- !cnt; - incr cnt - ); - tagged - - let cfun r f = hcons @@ Cfun (r, f) - - let cbinop op r x y = cfun r (Fbinop (op, x, y)) - - let bvand = cbinop Band - let bvor = cbinop Bor - let bvxor = cbinop Bxor - let bvadd = cbinop Badd - let bvsub r x y = - (* r = x - y <-> x = r + y *) - bvadd x r y - let bvmul = cbinop Bmul - let bvudiv = cbinop Budiv - let bvurem = cbinop Burem - let bvshl = cbinop Bshl - let bvlshr = cbinop Blshr - - let crel r = hcons @@ Crel r - - let cbinrel op x y = crel (Rbinrel (op, x, y)) - - let bvule = cbinrel Rule - let bvugt = cbinrel Rugt - - let equal c1 c2 = c1.tag = c2.tag - - let hash c = Hashtbl.hash c.tag - - let compare c1 c2 = Int.compare c1.tag c2.tag - - let fold_args f c acc = fold_args_repr f c.repr acc - - let subst rr nrr c = - hcons @@ subst_repr rr nrr c.repr + let propagate_bitlist dom ~ex c = + propagate_view ~ex dom (view c) - let propagate_bitlist ~ex c dom = - propagate_repr ~ex dom c.repr - - let propagate_interval ~ex c dom = - propagate_interval_repr ~ex dom c.repr + let propagate_interval dom ~ex c = + propagate_interval_view ~ex dom (view c) let const sz n = Shostak.Bitv.is_mine [ { bv = Cte (Z.extract n 0 sz); sz } ] @@ -652,59 +809,59 @@ end = struct | _ -> invalid_arg "const_value" (* Add the constraint: r = x *) - let add_eq acts r x = - acts.Rel_utils.acts_add_eq r x + let add_eq ~ex acts r x = + acts.acts_add_eq ~ex r x (* Add the constraint: r = c *) - let add_eq_const acts r c = - add_eq acts r @@ const (bitwidth r) c + let add_eq_const ~ex acts r c = + add_eq ~ex acts r @@ const (bitwidth r) c (* Add the constraint: r = x & c *) - let add_and_const acts r x c = + let add_and_const ~ex acts r x c = (* TODO: apply to extractions for any [c]? Could be expensive for Shostak *) if Z.equal c Z.zero then ( - add_eq_const acts r Z.zero; + add_eq_const ~ex acts r Z.zero; true ) else if Z.equal c (Z.extract Z.minus_one 0 (bitwidth r)) then ( - add_eq acts r x; + add_eq ~ex acts r x; true ) else false (* Add the constraint: r = x | c *) - let add_or_const acts r x c = + let add_or_const ~ex acts r x c = (* TODO: apply to extractions for any [c]? Could be expensive for Shostak *) if Z.equal c Z.zero then ( - add_eq acts r x; + add_eq ~ex acts r x; true ) else if Z.equal c (Z.extract Z.minus_one 0 (bitwidth r)) then ( - add_eq_const acts r Z.minus_one; + add_eq_const ~ex acts r Z.minus_one; true ) else false (* Add the constraint: r = x ^ c *) - let add_xor_const acts r x c = + let add_xor_const ~ex acts r x c = (* TODO: apply to extractions for any [c]? Could be expensive for Shostak *) if Z.equal c Z.zero then ( - add_eq acts r x; + add_eq ~ex acts r x; true ) else if Z.equal c (Z.extract Z.minus_one 0 (bitwidth r)) then ( - add_eq acts r + add_eq ~ex acts r (Shostak.Bitv.is_mine @@ Bitv.lognot @@ Shostak.Bitv.embed x); true ) else false (* Add the constraint: r = x + c *) - let add_add_const acts r x c = + let add_add_const ~ex acts r x c = let sz = bitwidth r in if Z.equal c Z.zero then ( - add_eq acts r x; + add_eq ~ex acts r x; true ) else if X.is_constant r then ( (* c1 = x + c2 -> x = c1 - c2 *) - add_eq_const acts x Z.(value r - c); + add_eq_const ~ex acts x Z.(value r - c); true ) else if Z.testbit c (sz - 1) then (* Due to the modular nature of arithmetic on bit-vectors, [y = x + c] @@ -719,7 +876,7 @@ end = struct are actually equivalent, so we just pick a normalized order between x and r. *) if X.hash_cmp r x > 0 then ( - acts.acts_add_constraint (bvadd x r (const (bitwidth r) c)); + acts.acts_add_constraint ~ex (bvadd x r (const (bitwidth r) c)); true ) else false @@ -727,16 +884,16 @@ end = struct (* r = x - c -> x = r + c (mod 2^sz) *) let c = Z.neg @@ Z.signed_extract c 0 sz in assert (Z.sign c > 0 && not (Z.testbit c sz)); - acts.acts_add_constraint (bvadd x r (const sz c)); + acts.acts_add_constraint ~ex (bvadd x r (const sz c)); true else false (* Add the constraint: r = x << c *) - let add_shl_const acts r x c = + let add_shl_const ~ex acts r x c = let sz = bitwidth r in match Z.to_int c with - | 0 -> add_eq acts r x + | 0 -> add_eq ~ex acts r x | n when n < sz -> assert (n > 0); let r_bitv = Shostak.Bitv.embed r in @@ -744,32 +901,32 @@ end = struct Shostak.Bitv.is_mine @@ Bitv.extract sz 0 (sz - 1 - n) (Shostak.Bitv.embed x) in - add_eq acts + add_eq ~ex acts (Shostak.Bitv.is_mine @@ Bitv.extract sz n (sz - 1) r_bitv) high_bits; - add_eq_const acts + add_eq_const ~ex acts (Shostak.Bitv.is_mine @@ Bitv.extract sz 0 (n - 1) r_bitv) Z.zero | _ | exception Z.Overflow -> - add_eq_const acts r Z.zero + add_eq_const ~ex acts r Z.zero (* Add the constraint: r = x * c *) - let add_mul_const acts r x c = + let add_mul_const ~ex acts r x c = if Z.equal c Z.zero then ( - add_eq_const acts r Z.zero; + add_eq_const ~ex acts r Z.zero; true ) else if Z.popcount c = 1 then ( let ofs = Z.numbits c - 1 in - add_shl_const acts r x (Z.of_int ofs); + add_shl_const ~ex acts r x (Z.of_int ofs); true ) else false (* Add the constraint: r = x >> c *) - let add_lshr_const acts r x c = + let add_lshr_const ~ex acts r x c = let sz = bitwidth r in match Z.to_int c with - | 0 -> add_eq acts r x + | 0 -> add_eq ~ex acts r x | n when n < sz -> assert (n > 0); let r_bitv = Shostak.Bitv.embed r in @@ -777,14 +934,14 @@ end = struct Shostak.Bitv.is_mine @@ Bitv.extract sz n (sz - 1) (Shostak.Bitv.embed x) in - add_eq acts + add_eq ~ex acts (Shostak.Bitv.is_mine @@ Bitv.extract sz 0 (sz - 1 - n) r_bitv) low_bits; - add_eq_const acts + add_eq_const ~ex acts (Shostak.Bitv.is_mine @@ Bitv.extract sz (sz - n) (sz - 1) r_bitv) Z.zero | _ | exception Z.Overflow -> - add_eq_const acts r Z.zero + add_eq_const ~ex acts r Z.zero (* Ground evaluation rules for binary operators. *) let eval_binop op ty x y = @@ -823,117 +980,123 @@ end = struct evaluated is assumed to be dealt with prior to calling this function. Algebraic rules (e.g. x & x = x) are in [rw_binop_algebraic].*) - let rw_binop_const acts op r x y = + let rw_binop_const ~ex acts op r x y = (* NB: for commutative operators, arguments are sorted, so the second argument can only be constant if the first argument also is constant. *) match op with | Band when X.is_constant x -> - add_and_const acts r y (value x) + add_and_const ~ex acts r y (value x) | Band when X.is_constant y -> - add_and_const acts r x (value y) + add_and_const ~ex acts r x (value y) | Band -> false | Bor when X.is_constant x -> - add_or_const acts r y (value x) + add_or_const ~ex acts r y (value x) | Bor when X.is_constant y -> - add_or_const acts r x (value y) + add_or_const ~ex acts r x (value y) | Bor -> false | Bxor when X.is_constant x -> - add_xor_const acts r y (value x) + add_xor_const ~ex acts r y (value x) | Bxor when X.is_constant y -> - add_xor_const acts r x (value y) + add_xor_const ~ex acts r x (value y) | Bxor when X.is_constant r -> - add_xor_const acts x y (value r) + add_xor_const ~ex acts x y (value r) | Bxor -> false | Badd when X.is_constant x -> - add_add_const acts r y (value x) + add_add_const ~ex acts r y (value x) | Badd when X.is_constant y -> - add_add_const acts r x (value y) + add_add_const ~ex acts r x (value y) | Badd -> false | Bmul when X.is_constant x -> - add_mul_const acts r y (value x) + add_mul_const ~ex acts r y (value x) | Bmul when X.is_constant y -> - add_mul_const acts r x (value y) + add_mul_const ~ex acts r x (value y) | Bmul -> false | Budiv | Burem -> false (* shifts becomes a simple extraction when we know the right-hand side *) | Bshl when X.is_constant y -> - add_shl_const acts r x (value y); + add_shl_const ~ex acts r x (value y); true | Bshl -> false | Blshr when X.is_constant y -> - add_lshr_const acts r x (value y); + add_lshr_const ~ex acts r x (value y); true | Blshr -> false (* Algebraic rewrite rules for binary operators. Rules based on constant simplifications are in [rw_binop_const]. *) - let rw_binop_algebraic acts op r x y = + let rw_binop_algebraic ~ex acts op r x y = match op with (* x & x = x ; x | x = x *) | Band | Bor when X.equal x y -> - add_eq acts r x; true + add_eq ~ex acts r x; true (* r ^ x ^ x = 0 <-> r = 0 *) | Bxor when X.equal x y -> - add_eq_const acts r Z.zero; true + add_eq_const ~ex acts r Z.zero; true | Bxor when X.equal r x -> - add_eq_const acts y Z.zero; true + add_eq_const ~ex acts y Z.zero; true | Bxor when X.equal r y -> - add_eq_const acts x Z.zero; true + add_eq_const ~ex acts x Z.zero; true | Badd when X.equal x y -> (* r = x + x -> r = 2x -> r = x << 1 *) - add_shl_const acts r x Z.one; true + add_shl_const ~ex acts r x Z.one; true | Badd when X.equal r x -> (* x = x + y -> y = 0 *) - add_eq_const acts y Z.zero; true + add_eq_const ~ex acts y Z.zero; true | Badd when X.equal r y -> (* y = x + y -> x = 0 *) - add_eq_const acts x Z.zero; true + add_eq_const ~ex acts x Z.zero; true | _ -> false - let simplify_binop acts op r x y = + let simplify_binop ~ex acts op r x y = if X.is_constant x && X.is_constant y then ( - add_eq acts r @@ + add_eq ~ex acts r @@ eval_binop op (X.type_info r) (value x) (value y); true ) else - rw_binop_const acts op r x y || - rw_binop_algebraic acts op r x y + rw_binop_const ~ex acts op r x y || + rw_binop_algebraic ~ex acts op r x y - let simplify_fun_t acts r = function - | Fbinop (op, x, y) -> simplify_binop acts op r x y + let simplify_fun_t uf acts r = function + | Fbinop (op, x, y) -> + let r, ex_r = Uf.find_r uf r in + let x, ex_x = Uf.find_r uf x in + let y, ex_y = Uf.find_r uf y in + let ex = Explanation.union ex_r (Explanation.union ex_x ex_y) in + simplify_binop ~ex acts op r x y - let simplify_binrel acts op x y = + let simplify_binrel ~ex acts op x y = match op with | Rugt when X.equal x y -> - acts.Rel_utils.acts_add_eq X.top X.bot; + acts.acts_add_eq ~ex X.top X.bot; true | Rule | Rugt -> false - let simplify_rel_t acts = function - | Rbinrel (op, x, y) -> simplify_binrel acts op x y + let simplify_rel_t uf acts = function + | Rbinrel (op, x, y) -> + let x, ex_x = Uf.find_r uf x in + let y, ex_y = Uf.find_r uf y in + simplify_binrel ~ex:(Explanation.union ex_x ex_y) acts op x y - let simplify_repr acts = function - | Cfun (r, f) -> simplify_fun_t acts r f - | Crel r -> simplify_rel_t acts r + let simplify_view uf acts = function + | Cfun (r, f) -> simplify_fun_t uf acts r f + | Crel r -> simplify_rel_t uf acts r - let simplify c acts = - simplify_repr acts c.repr + let simplify uf c acts = + simplify_view uf acts (view c) end -module Constraints = Rel_utils.Constraints_make(Constraint) - let extract_binop = let open Constraint in function | Sy.BVand -> Some bvand @@ -948,18 +1111,33 @@ let extract_binop = | BVlshr -> Some bvlshr | _ -> None -let extract_constraints bcs uf r t = +let extract_term r terms = + if X.is_a_leaf r then SX.add r terms + else terms + +let extract_constraints terms domain int_domain uf r t = match E.term_view t with | { f = Op op; xs = [ x; y ]; _ } -> ( match extract_binop op with | Some mk -> let rx, exx = Uf.find uf x and ry, exy = Uf.find uf y in - Constraints.add - ~ex:(Ex.union exx exy) (mk r rx ry) bcs - | _ -> bcs + let c = mk r rx ry in + let ex = Ex.union exx exy in + let domain = + Bitlist_domains.watch (explained ~ex c) rx @@ + Bitlist_domains.watch (explained ~ex c) ry @@ + domain + in + let int_domain = + Interval_domains.watch (explained ~ex c) rx @@ + Interval_domains.watch (explained ~ex c) ry @@ + int_domain + in + terms, domain, int_domain + | None -> extract_term r terms, domain, int_domain ) - | _ -> bcs + | _ -> extract_term r terms, domain, int_domain let rec mk_eq ex lhs w z = match lhs with @@ -1010,7 +1188,7 @@ let add_eqs = module Any_constraint = struct type t = - | Constraint of Constraint.t Rel_utils.explained + | Constraint of Constraint.t explained | Structural of X.r (** Structural constraint associated with [X.r]. See {!Rel_utils.Bitlist_domains.structural_propagation}. *) @@ -1025,17 +1203,27 @@ module Any_constraint = struct | Constraint c -> 2 * Constraint.hash c.value | Structural r -> 2 * X.hash r + 1 - let propagate constraint_propagate structural_propagation c d = + let propagate constraint_propagate structural_propagation c = Steps.incr CP; match c with | Constraint { value; explanation = ex } -> - constraint_propagate ~ex value d + constraint_propagate ~ex value | Structural r -> - structural_propagation d r + structural_propagation r end module QC = Uqueue.Make(Any_constraint) +let propagate_queue queue constraint_propagate structural_propagation = + try + while true do + Any_constraint.propagate + constraint_propagate + structural_propagation + (QC.pop queue) + done + with QC.Empty -> () + let finite_lower_bound = function | Intervals_intf.Unbounded -> Z.zero | Closed n -> n @@ -1063,8 +1251,9 @@ let finite_upper_bound ~size:sz = function five most-significant bits, denoted [00110???]. Therefore, a bit-vector bl = [0??1???0] can be refined into [00110??0]. *) let constrain_bitlist_from_interval ~size:sz bv int = - let open Bitlist_domains.Ephemeral in - + let + open Rel_utils.HandleNotations(Bitlist_domain)(Bitlist_domains.Ephemeral) + in let inf, inf_ex = Intervals.Int.lower_bound int in let inf = finite_lower_bound inf in let sup, sup_ex = Intervals.Int.upper_bound int in @@ -1094,7 +1283,9 @@ let constrain_bitlist_from_interval ~size:sz bv int = [Bitlist.decrease_upper_bound] on all the constituent intervals of an union; see the documentation of these functions for details. *) let constrain_interval_from_bitlist ~size:sz int bv = - let open Interval_domains.Ephemeral in + let + open Rel_utils.HandleNotations(Interval_domain)(Interval_domains.Ephemeral) + in let ex = Bitlist.explanation bv in (* Handy wrapper around [of_complement] *) let remove ~ex i2 i1 = @@ -1144,126 +1335,252 @@ let constrain_interval_from_bitlist ~size:sz int bv = acc ) !!int !!int -let propagate_bitlist queue touched bcs dom = - let touch_c c = QC.push queue (Constraint c) in - let touch r = - HX.replace touched r (); - QC.push queue (Structural r); - Constraints.iter_parents touch_c r bcs - in - try - while true do - Bitlist_domains.Ephemeral.iter_changed touch dom; - Bitlist_domains.Ephemeral.clear_changed dom; - Any_constraint.propagate - Constraint.propagate_bitlist - Bitlist_domains.Ephemeral.structural_propagation - (QC.pop queue) dom - done - with QC.Empty -> () +let iter_parents a f t = + match Rel_utils.XComparable.Map.find a t with + | cs -> Rel_utils.XComparable.Set.iter f cs + | exception Not_found -> () + +let propagate_bitlist queue vars dom = + let structural_propagation r = + let open Rel_utils.HandleNotations(Bitlist_domain)(Bitlist_domains_uf) in + let get r = !!(Bitlist_domains_uf.entry dom r) in + let update r d = + update ~ex:Explanation.empty (Bitlist_domains_uf.entry dom r) d + in + if X.is_a_leaf r then + iter_parents r (fun p -> + if X.is_a_leaf p then + assert (X.equal r p) + else + update p (CompositeBitlistDomain.map_domain get p) + ) vars + else + let iter_signed sz f { Bitv.value; negated } bl = + let bl = if negated then Bitlist.(extract (lognot bl)) 0 sz else bl in + f value bl + in + ignore @@ List.fold_left (fun (bl, w) { Bitv.bv; sz } -> + (* Extract the bitlist associated with the current component *) + let mid = w - sz in + let bl_tail = Bitlist.extract bl 0 mid in + let bl = Bitlist.extract bl mid (w - mid) in -let propagate_intervals queue touched bcs dom = - let touch_c c = QC.push queue (Constraint c) in - let touch r = - HX.replace touched r (); - QC.push queue (Structural r); - Constraints.iter_parents touch_c r bcs + match bv with + | Bitv.Cte z -> + assert (Z.numbits z <= sz); + (* Nothing to update, but still check for consistency! *) + ignore @@ Bitlist.intersect bl (Bitlist.exact z Ex.empty); + bl_tail, mid + | Other r -> + iter_signed sz update r bl; + (bl_tail, mid) + | Ext (r, r_size, i, j) -> + (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) + assert (i + r_size - j - 1 + w - mid = r_size); + let hi = Bitlist.(extract unknown 0 (r_size - j - 1)) in + let lo = Bitlist.(extract unknown 0 i) in + let bl_hd = Bitlist.((hi lsl (j + 1)) lor (bl lsl i) lor lo) in + iter_signed r_size update r bl_hd; + (bl_tail, mid) + ) ((get r), (bitwidth r)) (Shostak.Bitv.embed r) in - try - while true do - Interval_domains.Ephemeral.iter_changed touch dom; - Interval_domains.Ephemeral.clear_changed dom; - Any_constraint.propagate - Constraint.propagate_interval - Interval_domains.Ephemeral.structural_propagation - (QC.pop queue) dom - done - with QC.Empty -> () - -let propagate_all eqs bcs bdom idom = - (* Call [simplify_pending] first because it can remove constraints from the - pending set. *) - let eqs, bcs = Constraints.simplify_pending eqs bcs in + propagate_queue + queue + (Propagator.propagate_bitlist dom) + structural_propagation + +let propagate_intervals queue vars dom = + let structural_propagation r = + let open Rel_utils.HandleNotations(Interval_domain)(Interval_domains_uf) in + let get r = !!(Interval_domains_uf.entry dom r) in + let update r d = + update ~ex:Explanation.empty (Interval_domains_uf.entry dom r) d + in + if X.is_a_leaf r then + iter_parents r (fun p -> + if X.is_a_leaf p then + assert (X.equal r p) + else + update p (CompositeIntervalDomain.map_domain get p) + ) vars + else + let iter_signed f { Bitv.value; negated } sz int = + f value (if negated then Interval_domain.lognot sz int else int) + in + let int = get r in + let width = bitwidth r in + let j = + List.fold_left (fun j { Bitv.bv; sz } -> + (* sz = j - i + 1 => i = j - sz + 1 *) + let int = Intervals.Int.extract int ~ofs:(j - sz + 1) ~len:sz in + + begin match bv with + | Bitv.Cte z -> + (* Nothing to update, but still check for consistency *) + ignore @@ + Interval_domain.intersect int (Interval_domain.constant z) + | Other s -> iter_signed update s sz int + | Ext (r, r_size, i, j) -> + (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) + assert (i + r_size - j - 1 + sz = r_size); + let lo = Interval_domain.unknown (Tbitv i) in + let int = Intervals.Int.scale Z.(~$1 lsl i) int in + let hi = Interval_domain.unknown (Tbitv (r_size - j - 1)) in + let hi = + Intervals.Int.scale Z.(~$1 lsl Stdlib.(i + sz)) hi + in + iter_signed update r r_size Intervals.Int.(add hi (add int lo)) + end; + + (j - sz) + ) (width - 1) (Shostak.Bitv.embed r) + in + assert (j = -1) + in + propagate_queue + queue + (Propagator.propagate_interval dom) + structural_propagation + +module HC = Hashtbl.Make(Constraint) + +let simplify_all uf eqs touched (dom, idom) = + let eqs = ref eqs in + let to_add = HC.create 17 in + let simplify c c_ex (dom, idom) = + let acts_add_lit_view ~ex l = + eqs := (l, Explanation.union ex c_ex) :: !eqs + in + let acts_add_eq ~ex u v = + acts_add_lit_view ~ex (Uf.LX.mkv_eq u v) + in + let acts_add_constraint ~ex c = + let c = explained ~ex:(Explanation.union ex c_ex) c in + HC.replace to_add c.value c.explanation + in + let acts = + { acts_add_lit_view + ; acts_add_eq + ; acts_add_constraint } in + if Propagator.simplify uf c acts then + let c = explained ~ex:c_ex c in + (Bitlist_domains.unwatch c dom, Interval_domains.unwatch c idom) + else + (dom, idom) + in + let dom, idom = HC.fold simplify touched (dom, idom) in + !eqs, + HC.fold (fun c c_ex (dom, idom) -> + let c = explained ~ex:c_ex c in + Constraint.fold_args (fun r (dom, idom) -> + let r, _ = Uf.find_r uf r in + Bitlist_domains.watch c r dom, + Interval_domains.watch c r idom + ) c.value (dom, idom) + ) to_add (dom, idom) + +let rec propagate_all uf eqs bdom idom = (* Optimization to avoid unnecessary allocations *) - let do_all = Constraints.has_pending bcs in - let do_bitlist = do_all || Bitlist_domains.has_changed bdom in - let do_intervals = do_all || Interval_domains.has_changed idom in + let do_bitlist = Bitlist_domains.needs_propagation bdom in + let do_intervals = Interval_domains.needs_propagation idom in let do_any = do_bitlist || do_intervals in if do_any then - let queue = QC.create 17 in - let touch_pending queue = - Constraints.iter_pending (fun c -> QC.push queue (Constraint c)) bcs + let shostak_candidates = HX.create 17 in + let seen_constraints = HC.create 17 in + let bitlist_queue = QC.create 17 in + let interval_queue = QC.create 17 in + let touch_c queue c = + HC.replace seen_constraints c.value c.explanation; + QC.push queue (Constraint c) + in + let touch tbl queue r = + HX.replace tbl r (); + QC.push queue (Structural r) in let bitlist_changed = HX.create 17 in - let touched = HX.create 17 in - let bdom = Bitlist_domains.edit bdom in - let idom = Interval_domains.edit idom in + let interval_changed = HX.create 17 in + let bitlist_events = + { Rel_utils.evt_atomic_change = touch bitlist_changed bitlist_queue + ; evt_composite_change = touch bitlist_changed bitlist_queue + ; evt_watch_trigger = touch_c bitlist_queue } + and interval_events = + { Rel_utils.evt_atomic_change = touch interval_changed interval_queue + ; evt_composite_change = touch interval_changed interval_queue + ; evt_watch_trigger = touch_c interval_queue } + in + let bvars = Bitlist_domains.parents bdom in + let ivars = Interval_domains.parents idom in + + let bdom = Bitlist_domains.edit ~events:bitlist_events bdom in + let idom = Interval_domains.edit ~events:interval_events idom in + + let uf_bdom = Bitlist_domains_uf.wrap uf bdom in + let uf_idom = Interval_domains_uf.wrap uf idom in (* First, we propagate the pending constraints to both domains. Changes in the bitlist domain are used to shrink the interval domains. *) - touch_pending queue; - propagate_bitlist queue touched bcs bdom; - assert (QC.is_empty queue); + propagate_bitlist bitlist_queue bvars uf_bdom; + assert (QC.is_empty bitlist_queue); - touch_pending queue; HX.iter (fun r () -> - HX.replace bitlist_changed r (); - let sz = bitwidth r in - constrain_interval_from_bitlist ~size:sz - Interval_domains.Ephemeral.(handle idom r) - Bitlist_domains.Ephemeral.(!!(handle bdom r)) - ) touched; - HX.clear touched; - propagate_intervals queue touched bcs idom; - assert (QC.is_empty queue); + HX.replace shostak_candidates r (); + constrain_interval_from_bitlist ~size:(bitwidth r) + Interval_domains.Ephemeral.(entry idom r) + Bitlist_domains.Ephemeral.(Entry.domain (entry bdom r)) + ) bitlist_changed; + HX.clear bitlist_changed; + propagate_intervals interval_queue ivars uf_idom; + assert (QC.is_empty interval_queue); (* Now the interval domain is stable, but the new intervals may have an impact on the bitlist domains, so we must shrink them again when applicable. We repeat until a fixpoint is reached. *) - let bcs = Constraints.clear_pending bcs in - while HX.length touched > 0 do + while HX.length interval_changed > 0 do HX.iter (fun r () -> - let sz = bitwidth r in - constrain_bitlist_from_interval ~size:sz - Bitlist_domains.Ephemeral.(handle bdom r) - Interval_domains.Ephemeral.(!!(handle idom r)) - ) touched; - HX.clear touched; - propagate_bitlist queue touched bcs bdom; - assert (QC.is_empty queue); + constrain_bitlist_from_interval ~size:(bitwidth r) + Bitlist_domains.Ephemeral.(entry bdom r) + Interval_domains.Ephemeral.(Entry.domain (entry idom r)) + ) interval_changed; + HX.clear interval_changed; + propagate_bitlist bitlist_queue bvars uf_bdom; + assert (QC.is_empty bitlist_queue); HX.iter (fun r () -> - let sz = bitwidth r in - HX.replace bitlist_changed r (); - constrain_interval_from_bitlist ~size:sz - Interval_domains.Ephemeral.(handle idom r) - Bitlist_domains.Ephemeral.(!!(handle bdom r)) - ) touched; - HX.clear touched; - propagate_intervals queue touched bcs idom; - assert (QC.is_empty queue); + HX.replace shostak_candidates r (); + constrain_interval_from_bitlist ~size:(bitwidth r) + Interval_domains.Ephemeral.(entry idom r) + Bitlist_domains.Ephemeral.(Entry.domain (entry bdom r)) + ) bitlist_changed; + HX.clear bitlist_changed; + propagate_intervals interval_queue ivars uf_idom; + assert (QC.is_empty interval_queue); done; let eqs = HX.fold (fun r () acc -> - let d = Bitlist_domains.Ephemeral.(!!(handle bdom r)) in - let sz = bitwidth r in - add_eqs acc (Shostak.Bitv.embed r) sz d - ) bitlist_changed eqs + let d = Bitlist_domains.Ephemeral.(Entry.domain (entry bdom r)) in + add_eqs acc (Shostak.Bitv.embed r) (bitwidth r) d + ) shostak_candidates eqs + in + + let bdom, idom = + Bitlist_domains.snapshot bdom, Interval_domains.snapshot idom in + let eqs, (bdom, idom) = simplify_all uf eqs seen_constraints (bdom, idom) in - eqs, bcs, Bitlist_domains.snapshot bdom, Interval_domains.snapshot idom + (* Propagate again in case constraints were simplified. *) + propagate_all uf eqs bdom idom else - eqs, bcs, bdom, idom + eqs, bdom, idom type t = { delayed : Rel_utils.Delayed.t - ; constraints : Constraints.t + ; terms : SX.t ; size_splits : Q.t } let empty uf = { delayed = Rel_utils.Delayed.create ~is_ready:X.is_constant dispatch - ; constraints = Constraints.empty + ; terms = SX.empty ; size_splits = Q.one }, Uf.GlobalDomains.add (module Bitlist_domains) Bitlist_domains.empty @@ Uf.GlobalDomains.add (module Interval_domains) Interval_domains.empty @@ @@ -1276,54 +1593,56 @@ let assume env uf la = Uf.GlobalDomains.find (module Interval_domains) ds in let delayed, result = Rel_utils.Delayed.assume env.delayed uf la in - let (domain, int_domain, constraints, eqs, size_splits) = + let (domain, int_domain, eqs, size_splits) = try - let (constraints, eqs, size_splits) = - List.fold_left (fun (bcs, eqs, ss) (a, _root, ex, orig) -> - let ss = - match orig with - | Th_util.CS (Th_bitv, n) -> Q.(ss * n) - | _ -> ss - in - let is_1bit r = - match X.type_info r with - | Tbitv 1 -> true - | _ -> false - in - match a, orig with - | L.Eq (rr, nrr), Subst when is_bv_r rr -> - let bcs = Constraints.subst ~ex rr nrr bcs in - (bcs, eqs, ss) - | Builtin (is_true, BVULE, [x; y]), _ -> - let x, exx = Uf.find_r uf x in - let y, exy = Uf.find_r uf y in - let ex = Ex.union ex @@ Ex.union exx exy in - let c = - if is_true then - Constraint.bvule x y - else - Constraint.bvugt x y - in - let bcs = Constraints.add ~ex c bcs in - (bcs, eqs, ss) - | L.Distinct (false, [rr; nrr]), _ when is_1bit rr -> - (* We don't (yet) support [distinct] in general, but we must - support it for case splits to avoid looping. - - We are a bit more general and support it for 1-bit vectors, for - which `distinct` can be expressed using `bvnot`. *) - let not_nrr = - Shostak.Bitv.is_mine (Bitv.lognot (Shostak.Bitv.embed nrr)) - in - (bcs, (Uf.LX.mkv_eq rr not_nrr, ex) :: eqs, ss) - | _ -> (bcs, eqs, ss) + let (domain, int_domain, eqs, size_splits) = + List.fold_left + (fun (domain, int_domain, eqs, ss) (a, _root, ex, orig) -> + let ss = + match orig with + | Th_util.CS (Th_bitv, n) -> Q.(ss * n) + | _ -> ss + in + let is_1bit r = + match X.type_info r with + | Tbitv 1 -> true + | _ -> false + in + match a, orig with + | L.Builtin (is_true, BVULE, [x; y]), _ -> + let x, exx = Uf.find_r uf x in + let y, exy = Uf.find_r uf y in + let ex = Ex.union ex @@ Ex.union exx exy in + let c = + if is_true then + Constraint.bvule x y + else + Constraint.bvugt x y + in + (* Only watch comparisons on the interval domain since we don't + propagate them in the bitlist domain. . *) + let int_domain = + Interval_domains.watch (explained ~ex c) x @@ + Interval_domains.watch (explained ~ex c) y @@ + int_domain + in + (domain, int_domain, eqs, ss) + | L.Distinct (false, [rr; nrr]), _ when is_1bit rr -> + (* We don't (yet) support [distinct] in general, but we must + support it for case splits to avoid looping. + + We are a bit more general and support it for 1-bit vectors, + for which `distinct` can be expressed using `bvnot`. *) + let not_nrr = + Shostak.Bitv.is_mine (Bitv.lognot (Shostak.Bitv.embed nrr)) + in + (domain, int_domain, (Uf.LX.mkv_eq rr not_nrr, ex) :: eqs, ss) + | _ -> (domain, int_domain, eqs, ss) ) - (env.constraints, [], env.size_splits) + (domain, int_domain, [], env.size_splits) la in - let eqs, constraints, domain, int_domain = - propagate_all eqs constraints domain int_domain - in + let eqs, domain, int_domain = propagate_all uf eqs domain int_domain in if Options.get_debug_bitv () && not (Lists.is_empty eqs) then ( Printer.print_dbg ~module_name:"Bitv_rel" ~function_name:"assume" @@ -1331,11 +1650,8 @@ let assume env uf la = Printer.print_dbg ~module_name:"Bitv_rel" ~function_name:"assume" "interval domain: @[%a@]" Interval_domains.pp int_domain; - Printer.print_dbg - ~module_name:"Bitv_rel" ~function_name:"assume" - "bitlist constraints: @[%a@]" Constraints.pp constraints; ); - (domain, int_domain, constraints, eqs, size_splits) + (domain, int_domain, eqs, size_splits) with Bitlist.Inconsistent ex | Interval_domain.Inconsistent ex -> raise @@ Ex.Inconsistent (ex, Uf.cl_extract uf) in @@ -1345,7 +1661,7 @@ let assume env uf la = let result = { result with assume = List.rev_append assume result.assume } in - { delayed ; constraints ; size_splits }, + { delayed ; size_splits ; terms = env.terms }, Uf.GlobalDomains.add (module Bitlist_domains) domain @@ Uf.GlobalDomains.add (module Interval_domains) int_domain ds, result @@ -1365,7 +1681,9 @@ let case_split env uf ~for_model = constrained variables, all the remaining variables. [nunk] is the number of unknown bits. *) - let f_acc r bl acc = + let f_acc r acc = + let r, _ = Uf.find_r uf r in + let bl = Bitlist_domains.get r domain in let nunk = Z.popcount (Bitlist.unknown_bits bl) in if nunk = 0 then acc @@ -1376,28 +1694,21 @@ let case_split env uf ~for_model = Some (nunk', SX.add r xs) | _ -> Some (nunk, SX.singleton r) in - let f_acc' r acc = - let r, _ = Uf.find_r uf r in - List.fold_left (fun acc { Bitv.bv; _ } -> - match bv with - | Bitv.Cte _ -> acc - | Other r | Ext (r, _, _, _) -> - let bl = Bitlist_domains.get r.value domain in - f_acc r.value bl acc - ) acc (Shostak.Bitv.embed r) - in let _, candidates = - match Constraints.fold_args f_acc' env.constraints None with + match SX.fold f_acc env.terms None with | Some (nunk, xs) -> nunk, xs - | _ -> - match Bitlist_domains.fold_leaves f_acc domain None with - | Some (nunk, xs) -> nunk, xs - | None -> 0, SX.empty + | None -> 0, SX.empty in (* For now, just pick a value for the most significant bit. *) match SX.choose candidates with | r -> - let bl = Bitlist_domains.get r domain in + let rr, _ = Uf.find_r uf r in + let bl = Bitlist_domains.get rr domain in + let r = + let es = Uf.rclass_of uf r in + try Uf.make uf (Expr.Set.choose es) + with Not_found -> r + in let w = bitwidth r in let unknown = Z.extract (Bitlist.unknown_bits bl) 0 w in let bitidx = Z.numbits unknown - 1 in @@ -1415,15 +1726,22 @@ let case_split env uf ~for_model = | exception Not_found -> [] let add env uf r t = + let ds = Uf.domains uf in let delayed, eqs = Rel_utils.Delayed.add env.delayed uf r t in - let env, eqs = + let env, ds, eqs = if is_bv_r r then - let constraints = extract_constraints env.constraints uf r t in - { env with constraints }, eqs + let dom = Uf.GlobalDomains.find (module Bitlist_domains) ds in + let idom = Uf.GlobalDomains.find (module Interval_domains) ds in + let terms, dom, idom = extract_constraints env.terms dom idom uf r t in + { env with terms }, + Uf.GlobalDomains.add (module Bitlist_domains) dom @@ + Uf.GlobalDomains.add (module Interval_domains) idom @@ + ds, + eqs else - env, eqs + env, ds, eqs in - { env with delayed }, Uf.domains uf, eqs + { env with delayed }, ds, eqs let optimizing_objective _env _uf _o = None diff --git a/src/lib/reasoners/rel_utils.ml b/src/lib/reasoners/rel_utils.ml index 854077ff2..fd1267e94 100644 --- a/src/lib/reasoners/rel_utils.ml +++ b/src/lib/reasoners/rel_utils.ml @@ -217,6 +217,136 @@ end = struct MX.iter (fun r -> OMap.iter (fun op -> Expr.Set.iter (f r op))) t.used_by end +module type Map_like = sig + (** Minimal signature for a persistent map type, used by [EphemeralMap]. *) + + type 'a t + + type key + + val find : key -> 'a t -> 'a + + val add : key -> 'a -> 'a t -> 'a t +end + +module type Hashtbl_like = sig + (** Minimal signature for an imperative map type, used by [EphemeralMap]. *) + + type 'a t + + type key + + val create : int -> 'a t + + val find : 'a t -> key -> 'a + + val replace : 'a t -> key -> 'a -> unit + + val fold : (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b +end + +module EphemeralMap + (MX : Map_like) + (HX : Hashtbl_like with type key = MX.key) + : sig + (** This module implements an ephemeral (mutable) interface for efficient + (repeated) lookup and update to the underlying persistent map, as well + as conversion functions between persistent and ephemeral maps. *) + + type 'a t + (** The type of ephemeral maps with values of type ['a]. *) + + type key = MX.key + (** The type of keys in the ephemeral map. *) + + module Entry : sig + (** Entries associate a (mutable) content to keys in the map. *) + + type 'a t + (** The type of entries with values ['a]. *) + + val content : 'a t -> 'a + (** [content e] is the content associated with [key e] in the map. *) + + val set_content : 'a t -> 'a -> unit + (** [set_content e v] sets the content of entry [e] to [v]. This + overwrites any pre-existing content associated with [e]. *) + end + + val entry : 'a t -> key -> 'a Entry.t + (** [entry t k] returns an entry associated with key [k] in the map. + + Each key is associated with a single entry: calling [entry t k] several + times will always return the same entry. *) + + val edit : default:(key -> 'a) -> 'a MX.t -> 'a t + (** [edit ~default t] returns an ephemeral copy of [t] for edition. + + The [default] argument is used to compute a default value for missing + keys. *) + + val snapshot : 'a t -> 'a MX.t + (** [snapshot t] computes a persistent snapshot of the ephemeral map [t], + applying all the changes made using [set_content]. Entries that were + never written to using [set_content] are unchanged, even if they contain + a [default] value due to not present in the map when it was [edit]ed. *) + end = +struct + type key = MX.key + + module Entry = struct + type 'a t = + { key : MX.key + ; mutable value : 'a + ; mutable dirty : bool + ; dirty_cache : 'a t HX.t } + + let content { value; _ } = value + + let set_dirty handle = + if not handle.dirty then ( + handle.dirty <- true; + HX.replace handle.dirty_cache handle.key handle + ) + + let set_content handle value = + set_dirty handle; + handle.value <- value + end + + type 'a t = + { values : 'a MX.t + ; handles : 'a Entry.t HX.t + ; dirty_cache : 'a Entry.t HX.t + ; default : MX.key -> 'a } + + let entry t r = + try HX.find t.handles r with Not_found -> + let handle = + { Entry.key = r + ; value = (try MX.find r t.values with Not_found -> t.default r) + ; dirty = false + ; dirty_cache = t.dirty_cache } + in + HX.replace t.handles r handle; + handle + + let edit ~default t = + let size = 17 in + { values = t + ; handles = HX.create size + ; dirty_cache = HX.create size + ; default } + + let snapshot t = + let persistent = t.values in + HX.fold (fun repr handle t -> + (* NB: we are in the [dirty_cache] so we know that the domain has been + updated. *) + MX.add repr (Entry.content handle) t + ) t.dirty_cache persistent +end + module type Domain = sig type t (** The type of domains for a single value. @@ -239,6 +369,12 @@ module type Domain = sig val filter_ty : Ty.t -> bool (** Filter for the types of values this domain can be attached to. *) + type constant + (** The type of constant values. *) + + val constant : constant -> t + (** [constant c] returns the singleton domain {m \{ c \}}. *) + val unknown : Ty.t -> t (** [unknown ty] returns a full domain for values of type [t]. @@ -256,634 +392,865 @@ module type Domain = sig @raise Inconsistent if [d1] and [d2] are not compatible (the intersection would be empty). *) +end +module type OffsetDomain = sig + (** This module represents domains to which (constant) offsets can be added or + removed. It extends the [Domain] signature. *) - val fold_leaves : (X.r -> t -> 'a -> 'a) -> X.r -> t -> 'a -> 'a - (** [fold_leaves f r t acc] folds [f] over the leaves of [r], deconstructing - the domain [t] according to the structure of [r]. + include Domain + + val add_offset : t -> constant -> t + (** [add_offset ofs d] adds the offset [ofs] to domain [d]. *) + + val sub_offset : t -> constant -> t + (** [sub_offset ofs d] removes the offset [ofs] from domain [d]. *) +end + +module type EphemeralDomainMap = sig + (** This module provides a signature for ephemeral domain maps: imperative + mappings from some key type to a domain type. *) + + type t + (** The type of ephemeral domain maps, i.e. an imperative structure mapping + keys to their current domain. *) - It is assumed that [t] already contains any justification required for - it to apply to [r]. + type key + (** The type of keys in the ephemeral map. *) - @raise Inconsistent if [r] cannot possibly be in the domain of [t]. *) + type domain + (** The type of domains. *) - val map_leaves : (X.r -> t) -> X.r -> t - (** [map_leaves f r] is the "inverse" of [fold_leaves] in the sense that - it rebuilds a domain for [r] by using [f] to access the domain for each - of [r]'s leaves. *) + module Entry : sig + type t + (** A mutable entry associated with a given key. Can be used to access and + update the associated domain imperatively. A single (physical) entry is + associated with a given key. *) + + val domain : t -> domain + (** Return the domain associated with this entry. *) + + val set_domain : t -> domain -> unit + (** Intersect the domain associated with this entry and the provided + [domain]. The explanation [ex] justifies that the [domain] applies to + the entry's key. + + @raise Domain.Inconsistent if the intersection is empty. *) + end + + val entry : t -> key -> Entry.t + (** [entry t k] returns the [handle] associated with [k]. + + There is a unique entry associated with each key [k] that is created + on-the-fly when [handle t k] is called for the first time. + + The domain associated with the entry is initialized from the underlying + persistent domain the first time it is accessed, and updated with + [update]. *) end -module type Domains = sig - (** Extended signature for global domains. *) +module type OrderedType = sig + (** Module signature for an ordered type equipped with a [compare] function. + + This is similar to [Set.OrderedType] and [Map.OrderedType], but includes + pre-built [Set] and [Map] modules. *) + + type t - include Uf.GlobalDomain + val pp : t Fmt.t - type elt - (** The type of domains contained in the map. Each domain of type [elt] - applies to a single semantic value. *) + val compare : t -> t -> int - val get : X.r -> t -> elt - (** [get r t] returns the domain currently associated with [r] in [t]. *) + module Set : Set.S with type elt = t - val fold_leaves : (X.r -> elt -> 'a -> 'a) -> t -> 'a -> 'a - (** [fold f t acc] folds [f] over all the domains in [t] that are associated - with leaves. *) + module Map : Map.S with type key = t +end - val has_changed : t -> bool - (** Returns [true] if any element is marked as changed. This can be used to - avoid unnecessary calls to [edit]. +module type ComparableType = sig + (** Module signature combining [OrderedType] and [Hashtbl.HashedType]. - Elements are marked as changed when their domain shrinks due to a call to - [subst], or through the ephemeral API. Elements can be unmarked by - [clear_changed] in the ephemeral API. *) + This includes a pre-built [Table] module that implements the [Hashtbl.S] + signature. *) - module Ephemeral : sig - type handle - (** A mutable handle to the domain associated with a semantic value. Can be - used to access and update the domain. *) + include OrderedType - val (!!) : handle -> elt - (** Return the domain associated with the [handle]. *) + val equal : t -> t -> bool - val update : ex:Explanation.t -> handle -> elt -> unit - (** Intersect the domain associated with the [handle] with the provided - [domain]. The explanation [ex] justifies that the [domain] applies to - the [handle]'s representative. + val hash : t -> int - If this changes the domain associated with the handle, the handle is - marked as changed. + module Table : Hashtbl.S with type key = t +end - @raise Domain.Inconsistent if the intersection is empty. *) +module DomainMap + (X : ComparableType) + (D : Domain) + : sig + (** A persistent map to a domain type, with an ephemeral interface. *) type t - (** Mutable mappings from semantic values to [domain]s. *) + (** The type of domain maps. *) - val handle : t -> X.r -> handle - (** [handle t r] returns the [handle] associated with [r]. + val pp : t Fmt.t + (** Pretty-printer for domain maps. *) - There is a unique handle associated with each semantic value [r] that is - created on-the-fly when [handle t r] is called for the first time. + val empty : t + (** The empty domain map. *) - The domain associated with the handle is initialized from the - underlying persistent domain the first time it is accessed, and updated - with [update]. *) + type key = X.t + (** The type of keys in the map. *) - val structural_propagation : t -> X.r -> unit - (** Perform structural propagation for the given representative. + type domain = D.t + (** The type of per-variable domains. *) - More precisely, if [r] is a leaf, the domain of [r] is propagated to any - semantic value that contains [r] as a leaf according to the structure of - that semantic value (using [Domain.map_leaves]); if [r] is not a leaf, - its domain is propagated to any of the leaves it contains. + val find : key -> t -> domain + (** Find the domain associatd with the given key. - We only perform *forward* structural propagation: if structural - propagation causes a domain of a leaf or parent to be changed, then we - only mark that leaf or parent as changed. + @raise Not_found if there is no domain associated with the key. *) - @raise Inconsistent if an inconsistency if detected during structural - propagation. *) + val add : key -> domain -> t -> t + (** Adds a domain associated with a given key. - val iter_changed : (X.r -> unit) -> t -> unit - (** Iterate over all the semantic values that have been marked as changed - since the last call to [clear_changed]. Values are marked as changed by - [update] whenever their domain shrinks. + {b Warning}: If the key is not constant, [add] updates the domain + associated with the variable part of the key, and hence influences the + domains of other keys that have the same variable part as this key. *) - {b Warning}: The behavior is not specified if the ephemeral domain is - modified during iteration, such as by calling [update] or - [structural_propagation]. *) + val remove : key -> t -> t + (** Removes the domain associated with a single variable. This will + effectively remove the domains associated with all keys that have the + same variable part. *) - val clear_changed : t -> unit - (** Remove the [changed] flag from all values. *) - end + val needs_propagation : t -> bool + (** Returns [true] if the domain map needs propagation, i.e. if the domain + associated with any variable has changed. *) - val edit : t -> Ephemeral.t - (** [edit d] returns an ephemeral version of the domain that can be used for - editing. *) + module Ephemeral : EphemeralDomainMap + with type key = key and type domain = domain - val snapshot : Ephemeral.t -> t - (** [snapshot e] returns a persistent version of [e]. *) -end + val edit : + notify:(key -> unit) -> default:(key -> domain) -> t -> Ephemeral.t + (** Create an ephemeral domain map from the current domain map. -module Domains_make(Domain : Domain) : Domains with type elt = Domain.t = -struct - type elt = Domain.t + [notify] will be called whenever the domain associated with a variable + changes. *) - exception Inconsistent = Domain.Inconsistent + val snapshot : Ephemeral.t -> t + (** Convert back a (modified) ephemeral domain map into a persistent one. *) + end - type t = { - domains : Domain.t MX.t ; - (** Map from tracked representatives to their domain *) += +struct + module MX = X.Map + module SX = X.Set + module HX = X.Table + module EX = EphemeralMap(MX)(HX) - changed : SX.t ; - (** Representatives whose domain has changed since the last flush *) + type t = + { domains : D.t MX.t + ; changed : SX.t } - leaves_map : SX.t MX.t ; - (** Map from leaves to the *tracked* representatives that contains them *) - } + type key = X.t - type _ Uf.id += Id : t Uf.id + type domain = D.t let pp ppf t = - Fmt.(iter_bindings ~sep:semi MX.iter - (box @@ pair ~sep:(any " ->@ ") X.print Domain.pp) - ) + Fmt.iter_bindings ~sep:Fmt.semi MX.iter + (Fmt.box @@ Fmt.pair ~sep:(Fmt.any " ->@ ") X.pp D.pp) ppf t.domains let empty = - { domains = MX.empty ; changed = SX.empty ; leaves_map = MX.empty } - - let filter_ty = Domain.filter_ty - - let r_add r leaves_map = - List.fold_left (fun leaves_map leaf -> - MX.update leaf (function - | Some parents -> Some (SX.add r parents) - | None -> Some (SX.singleton r) - ) leaves_map - ) leaves_map (X.leaves r) - - let create_domain r = - Domain.map_leaves (fun r -> - Domain.unknown (X.type_info r) - ) r - - let add r t = - if MX.mem r t.domains then t else - (* Note: we do not need to mark [r] as needing propagation, because no - constraints applied to it yet. Any constraint that apply to [r] will - already be marked as pending due to being newly added. *) - let d = create_domain r in - let domains = MX.add r d t.domains in - let leaves_map = r_add r t.leaves_map in - { t with domains; leaves_map } - - let r_remove r leaves_map = - List.fold_left (fun leaves_map leaf -> - MX.update leaf (function - | Some parents -> - let parents = SX.remove r parents in - if SX.is_empty parents then None else Some parents - | None -> None - ) leaves_map - ) leaves_map (X.leaves r) - - let remove r t = - let changed = SX.remove r t.changed in - let domains = MX.remove r t.domains in - let leaves_map = r_remove r t.leaves_map in - { changed; domains; leaves_map } - - let get r t = - (* We need to catch [Not_found] because of fresh terms that can be added - by the solver and for which we don't call [add]. Note that in this - case, only structural constraints can apply to [r]. *) - try MX.find r t.domains with Not_found -> create_domain r - - (* Marked as unsafe because we trust the [changed] flag from the caller. *) - let unsafe_update ?(changed = true) r d t = - match MX.find r t.domains with - | od -> - (* Both domains are already valid for [r], we can intersect them - without additional justifications. *) - let d = Domain.intersect od d in - if Domain.equal od d then - t - else - let domains = MX.add r d t.domains in - let changed = if changed then SX.add r t.changed else t.changed in - { t with domains; changed } - | exception Not_found -> - (* We need to catch [Not_found] because of fresh terms that can be added - by the solver and for which we don't call [add]. *) - let d = Domain.intersect d (create_domain r) in - let domains = MX.add r d t.domains in - let changed = if changed then SX.add r t.changed else t.changed in - let leaves_map = r_add r t.leaves_map in - { domains; changed; leaves_map } - - let fold_leaves f t acc = - MX.fold (fun r _ acc -> - f r (get r t) acc - ) t.leaves_map acc + { domains = MX.empty + ; changed = SX.empty } - let subst ~ex rr nrr t = - (* Need to add [ex] to be a valid domain for [nrr] *) - let d = Domain.add_explanation ~ex (get rr t) in - let changed = SX.mem rr t.changed in - let t = remove rr t in - match MX.find nrr t.domains with - | nd -> - (* If there is an existing domain for [nrr], there might be - constraints applying to [nrr] prior to the substitution, and the - constraints that used to apply to [rr] will also apply to [nrr] - after the substitution. - - We need to notify changed to either of these constraints, so we - must notify if the domain is different from *either* the old - domain of [rr] or the old domain of [nrr]. *) - let nnd = Domain.intersect d nd in - let nrr_changed = not (Domain.equal nnd nd) in - let rr_changed = not (Domain.equal nnd d) in - let domains = - if nrr_changed then MX.add nrr nnd t.domains else t.domains - in - let changed = changed || rr_changed || nrr_changed in - let changed = - if changed then SX.add nrr t.changed else t.changed - in - { t with domains; changed } - | exception Not_found -> - (* If there is no existing domain for [nr], there were no - constraints applying to [nr] prior to the substitution. - - The only constraints that need to be notified are those that - were applying to [r], and they only need to be notified if the - new domain is different from the old domain of [r]. *) - let default = create_domain nrr in - let nd = Domain.intersect d default in - let rr_changed = not (Domain.equal nd d) in - (* Make sure to not add more constraints than necessary for the - representative domain. *) - let nd = if Domain.equal nd default then default else nd in - let domains = MX.add nrr nd t.domains in - let leaves_map = r_add nrr t.leaves_map in - let changed = changed || rr_changed in - let changed = - if changed then SX.add nrr t.changed else t.changed - in - { domains; changed; leaves_map } + let find x t = MX.find x t.domains + + let remove x t = + { domains = MX.remove x t.domains + ; changed = SX.remove x t.changed } - let has_changed t = - not @@ SX.is_empty t.changed + let add x d t = { t with domains = MX.add x d t.domains } + + let needs_propagation t = not (SX.is_empty t.changed) module Ephemeral = struct - type handle = - { repr : X.r - ; mutable domain : Domain.t - ; mutable dirty : bool - ; dirty_cache : handle HX.t - ; mutable changed : bool - ; changed_set : handle HX.t - } + type nonrec key = key + type nonrec domain = domain + + module Entry = struct + type t = + { entry : domain EX.Entry.t + ; key : key + ; notify : X.t -> unit } + + let domain { entry ; _ } = EX.Entry.content entry + + let set_domain { entry ; notify ; key } dom = + EX.Entry.set_content entry @@ dom; + notify key + end + + type t = + { domains : domain EX.t + ; notify : X.t -> unit } + + let entry t x = + { Entry.entry = EX.entry t.domains x + ; key = x + ; notify = t.notify } + end - let (!!) handle = handle.domain + let edit ~notify ~default t = + SX.iter notify t .changed; - let set_dirty handle = - if not handle.dirty then ( - handle.dirty <- true; - HX.replace handle.dirty_cache handle.repr handle - ) + { Ephemeral.domains = EX.edit ~default t.domains + ; notify } - let set_changed handle = - if not handle.changed then ( - set_dirty handle; - handle.changed <- true; - HX.replace handle.changed_set handle.repr handle - ) + let snapshot t = + { domains = EX.snapshot t.Ephemeral.domains + ; changed = SX.empty } +end - let update ~ex handle domain = - let domain = Domain.add_explanation ~ex domain in - let domain = Domain.intersect handle.domain domain in - if not (Domain.equal domain handle.domain) then ( - set_changed handle; - handle.domain <- domain - ) - type nonrec t = - { persistent : t - ; handles : handle HX.t - ; dirty_cache : handle HX.t - ; changed_set : handle HX.t } - - let handle t r = - try HX.find t.handles r with Not_found -> - let handle = - { repr = r - ; domain = get r t.persistent - ; dirty = false - ; dirty_cache = t.dirty_cache - ; changed = false - ; changed_set = t.changed_set } - in - HX.add t.handles r handle; - handle - - let structural_propagation t r = - (* Structural propagation is always correct and does not require - explanations because it follows the structure of the semantic value - itself. *) - let get r = !!(handle t r) in - let update r d = update ~ex:Explanation.empty (handle t r) d in - if X.is_a_leaf r then - match MX.find r t.persistent.leaves_map with - | parents -> - SX.iter (fun parent -> - if X.is_a_leaf parent then - assert (X.equal r parent) - else - update parent (Domain.map_leaves get parent) - ) parents - | exception Not_found -> () - else - Domain.fold_leaves (fun r d () -> update r d) r (get r) () +module BinRel(X : OrderedType)(W : OrderedType) : sig + (** This module provides a thin abstraction to keep track of binary relations + between values of two different types. *) - let iter_changed f t = HX.iter (fun r _ -> f r) t.changed_set + type t + (** The type of binary relations between [X.t] and [W.t]. *) - let clear_changed t = - HX.iter (fun _ h -> h.changed <- false) t.changed_set; - HX.clear t.changed_set - end + val empty : t + (** The empty relation. *) - let edit t = - let size = 17 in - let ephemeral = - { Ephemeral.persistent = { t with changed = SX.empty } - ; handles = HX.create size - ; dirty_cache = HX.create size - ; changed_set = HX.create size } + val add : X.t -> W.t -> t -> t + (** [add x w r] adds the tuple [(x, w)] to the relation. *) + + val add_many : X.t -> W.Set.t -> t -> t + + val range : X.t -> t -> W.Set.t + + val remove_dom : X.t -> t -> t + (** [remove_dom x r] removes all tuples of the form [(x, _)] from the + relation. *) + + val remove_range : W.t -> t -> t + (** [remove_range w r] removes all tuples of the form [(_, w)] from the + relation. *) + + val transfer_dom : X.t -> X.t -> t -> t + (** [transfer_dom x x' r] replaces all tuples of the form [(x, w)] in the + relation with the corresponding [(x', w)] tuple. *) + + val iter_range : X.t -> (W.t -> unit) -> t -> unit + (** [iter_range x f r] calls [f] on all the [w] such that [(x, w)] is in the + relation. *) + + val fold_range : X.t -> (W.t -> 'a -> 'a) -> t -> 'a -> 'a + (** [fold_range x f r acc] folds [f] over all the [w] such that [(x, w)] is in + the relation.*) +end = struct + module MX = X.Map + module MW = W.Map + module SX = X.Set + module SW = W.Set + + type t = + { watches : SW.t MX.t ; + (** Reverse map from variables to their watches. Used to trigger watches + when a domain changes. *) + + watching : SX.t MW.t + (** Map from watches to the variables they watch. Used to be able to + remove watches. *) + } + + let range x t = + try MX.find x t.watches with Not_found -> W.Set.empty + + let empty = + { watches = MX.empty + ; watching = MW.empty } + + let add x w t = + let watches = + MX.update x (function + | None -> Some (SW.singleton w) + | Some ws -> Some (SW.add w ws)) t.watches + and watching = + MW.update w (function + | None -> Some (SX.singleton x) + | Some xs -> Some (SX.add x xs)) t.watching + in + { watches ; watching } + + let add_many x ws t = + let watches = + MX.update x (function + | None -> Some ws + | Some ws' -> Some (SW.union ws ws')) t.watches + and watching = + SW.fold (fun w watching -> + MW.update w (function + | None -> Some (SX.singleton x) + | Some xs -> Some (SX.add x xs)) watching + ) ws t.watching in - SX.iter (fun r -> - Ephemeral.set_changed (Ephemeral.handle ephemeral r) - ) t.changed; - ephemeral + { watches ; watching } + + let remove_range w t = + match MW.find w t.watching with + | xs -> + let watches = + SX.fold (fun x watches -> + MX.update x (function + | None -> + (* maps must be mutual inverses *) + assert false + | Some ws -> + let ws = SW.remove w ws in + if SW.is_empty ws then None else Some ws + ) watches + ) xs t.watches + in + let watching = MW.remove w t.watching in + { watches ; watching } + | exception Not_found -> t - let snapshot t = - assert (SX.is_empty t.Ephemeral.persistent.changed); - HX.fold (fun repr handle domains -> - unsafe_update - ~changed:handle.Ephemeral.changed repr handle.domain domains - ) t.Ephemeral.dirty_cache t.persistent + let remove_dom x t = + match MX.find x t.watches with + | ws -> + let watching = + SW.fold (fun w watching -> + MW.update w (function + | None -> + (* maps must be mutual inverses *) + assert false + | Some xs -> + let xs = SX.remove x xs in + if SX.is_empty xs then None else Some xs + ) watching + ) ws t.watching + and watches = MX.remove x t.watches in + { watches ; watching } + | exception Not_found -> t + + let fold_range x f t acc = + match MX.find x t.watches with + | ws -> SW.fold f ws acc + | exception Not_found -> acc + + let iter_range x f t = + match MX.find x t.watches with + | ws -> SW.iter f ws + | exception Not_found -> () + + let transfer_dom x x' t = + match MX.find x t.watches with + | ws -> + let watching = + SW.fold (fun w watching -> + MW.update w (function + | None -> + (* maps must be mutual inverses *) + assert false + | Some xs -> + Some (SX.add x' (SX.remove x xs)) + ) watching + ) ws t.watching + and watches = + MX.update x' (function + | None -> Some ws + | Some ws' -> Some (SW.union ws ws') + ) (MX.remove x t.watches) + in + { watches ; watching } + | exception Not_found -> t end -(** The ['c acts] type is used to register new facts and constraints in - [Constraint.simplify]. *) -type 'c acts = - { acts_add_lit_view : X.r L.view -> unit - (** Assert a semantic literal. *) - ; acts_add_eq : X.r -> X.r -> unit - (** Assert equality between two semantic values. *) - ; acts_add_constraint : 'c -> unit - (** Assert a new constraint. *) - } +(** Implementation of the [ComparableType] interface for semantic values. *) +module XComparable : ComparableType with type t = X.r = struct + type t = X.r -module type Constraint = sig - type t - (** The type of constraints. + let pp = X.print - Constraints apply to semantic values of type [X.r] as arguments. *) + let equal = X.equal - val pp : t Fmt.t - (** Pretty-printer for constraints. *) + let hash = X.hash - val compare : t -> t -> int - (** Comparison function for constraints. The comparison function is - arbitrary and has no semantic meaning. You should not depend on any of - its properties, other than it defines an (arbitrary) total order on - constraint representations. *) + let compare = X.hash_cmp - val fold_args : (X.r -> 'a -> 'a) -> t -> 'a -> 'a - (** [fold_args f c acc] folds function [f] over the arguments of constraint - [c]. + module Set = SX - During propagation, the constraint {b MUST} only look at (and update) - the domains associated of its arguments; it is not allowed to look at - the domains of other semantic values. This allows efficient updates of - the pending constraints. *) + module Map = MX - val subst : X.r -> X.r -> t -> t - (** [subst p v cs] replaces all the instances of [p] with [v] in the - constraint. + module Table = HX +end - Substitution can perform constraint simplification. *) +module type NormalForm = sig + (** Module signature for normal form computation. *) - val simplify : t -> t acts -> bool - (** [simplify c acts] simplifies the constraint [c] by calling appropriate - functions on [acts]. + type constant + (** The type of constant values. *) - {b Note}: All the facts and constraints added through [acts] must be - logically implied by [c] {b only}. Doing otherwise is a {b soundness bug}. + type atom + (** The type of atomic variables that cannot be decomposed further. *) - Returns [true] if the constraint has been fully simplified and can - be removed, and [false] otherwise. + type composite + (** The type of composite variables that are obtained through a combination of + atomic variables (e.g. a multi-variate polynomial). *) - {b Note}: Returning [true] will cause the constraint to be removed, even - if it was re-added with [acts_add_constraint]. If you want to add new - facts/constraints but keep the existing constraint (usually a bad idea), - return [false] instead. *) -end + type t = + | Constant of constant + (** A constant value. *) + | Atom of atom * constant + (** An atomic variable with a constant offset. *) + | Composite of composite * constant + (** A composite variable with a constant offset. *) + (** The type of normal forms. *) -type 'a explained = { value : 'a ; explanation : Explanation.t } + type expr + (** The underlying type of non-normalized expressions. *) -let explained ~ex value = { value ; explanation = ex } + val normal_form : expr -> t + (** [normal_form e] computes the normal form of expression [e]. *) +end -module Constraints_make(Constraint : Constraint) : sig - type t - (** The type of constraint sets. A constraint set records a set of - constraints that applies to semantic values, and remembers the relation - between constraints and semantic values. +module type CompositeType = sig + (** Extension of the [ComparableType] signature for composite types, i.e. + types that are built up from a collection of smaller components. *) - The constraints applying to a given semantic value can be recovered using - the [iter_pending] functions. + include ComparableType - New constraints are marked as "pending" when added to the constraint set - (whether by a call to [add] or following a substitution). These - constraints should ultimately be propagated; they can be accessed through - the [iter_pending]. Once pending constraints have been propagated, the - "pending" constraints should be cleared with [clear_pending]. *) + type atom + (** The type of atoms that build up a composite value. *) - val pp : t Fmt.t - (** Pretty-printer for constraint sets. *) + val fold : (atom -> 'a -> 'a) -> t -> 'a -> 'a + (** [fold f c acc] folds [f] over all the atoms that make up [c]. *) +end - val empty : t - (** The empty constraint set. *) - - val add : ex:Explanation.t -> Constraint.t -> t -> t - (** [add ~ex c t] adds the constraint [c] to the set [t]. - - The explanation [ex] justifies that the constraint [c] holds. If the same - constraint is added multiple times with different explanations, only one - of the explanations for the constraint will be kept. *) - - val subst : ex:Explanation.t -> X.r -> X.r -> t -> t - (** [subst ~ex p v t] replaces all instances of [p] with [v] in the - constraints. - - The explanation [ex] justifies the equality [p = v]. *) - - val iter_parents : (Constraint.t explained -> unit) -> X.r -> t -> unit - (** [iter_parents f r t] calls [f] on all the constraints that apply directly - to [r] (precisely, all the constraints [r] is an argument of). *) - - val iter_pending : (Constraint.t explained -> unit) -> t -> unit - (** [iter_pending f t] calls [f] on all the constraints currently marked as - pending. Constraints are marked as pending when they are added, including - when a new constraint is added due to substitution of an old constraint - (whether the old constraint was pending or not). *) - - val clear_pending : t -> t - (** [clear_pending t] returns a copy of [t] except that no constraints are - marked as pending. *) - - val has_pending : t -> bool - (** [has_pending t] returns [true] if there is any constraint marked as - pending. Hence if [has_pending t] returns [false], [iter_pending] and - [clear_pending] are guaranteed to be no-ops. Should only be used for - optimization. *) - - val fold_args : (X.r -> 'a -> 'a) -> t -> 'a -> 'a - (** [fold_args f t acc] folds [f] over all the term representatives that are - arguments of at least one constraint. *) - - val simplify_pending : - (X.r L.view * Explanation.t) list -> t -> - (X.r L.view * Explanation.t) list * t - (** Simplify the pending constraints. This takes as argument a list of - (explained) literals, and returns a list of (explained) literals, so - that constraint simplification is able to propagate new literals - (typically equalities) to the UF module. *) -end = struct - module CS = Set.Make(struct - type t = Constraint.t explained +module type CompositeDomain = sig + (** Module signature to build a domain for a composite type from the domain of + its component atoms. *) - let compare a b = Constraint.compare a.value b.value - end) + type var + (** The type of (composite) variables. *) - type t = { - args_map : CS.t MX.t ; - (** Mapping from semantic values to constraints involving them *) + type atom + (** The type of atomic variables. *) - leaves_map : CS.t MX.t ; - (** Mapping from semantic values to constraints they are a leaf of *) + type domain + (** The type of domains we are building. *) - active : CS.t ; - (** Set of all currently active constraints, i.e. constraints that must - hold in a model and will be propagated. *) + val map_domain : (atom -> domain) -> var -> domain + (** [map_domain f c] constructs a domain for [c] from a function [f] that + returns the domain of an atom. *) +end - pending : CS.t ; - (** Set of active constraints that have not yet been propagated *) +type ('a, 'c, 'w) events = + { evt_atomic_change : 'a -> unit + ; evt_composite_change : 'c -> unit + (** Called by the ephemeral interface when the domain associated with a + variable changes. *) + ; evt_watch_trigger : 'w -> unit + (** Called by the ephemeral interface when a watcher is triggered. *) } +(** Handlers for events used by the ephemeral interface. *) + +module type VariableType = sig + (** Extension of the [ComparableType] signature for variables that have an + associated type. *) + + include ComparableType + + val type_info : t -> Ty.t + (** [type_info x] returns the type of variable [x]. *) +end + +module Domains_make + (D : OffsetDomain) + (A : VariableType) + (C : CompositeType with type atom = A.t) + (CD : CompositeDomain + with type var = C.t + and type atom = A.t + and type domain = D.t) + (NF : NormalForm + with type atom = A.t + and type composite = C.t + and type constant = D.constant + and type expr = X.r) + (W : OrderedType) + : sig + include Uf.GlobalDomain + + val get : X.r -> t -> D.t + (** [get r t] returns the domain associated with semantic value [r]. *) + + val watch : W.t -> X.r -> t -> t + (** [watch w r t] associated the watch [w] with the domain of semantic value + [r]. The watch [w] is triggered whenever the domain associated with [r] + changes, and is preserved across substitutions (i.e. if [r] becomes + [nr], [w] will be transfered to [nr]). + + {b Note}: The watch [w] is also immediately triggered for a first + propagation. *) + + val unwatch : W.t -> t -> t + (** [unwatch w] removes [w] from all watch lists. It will no longer be + triggered. + + {b Note}: If [w] has already been triggered, it is not removed from the + triggered list. *) + + val needs_propagation : t -> bool + (** Returns [true] if the domains needs propagation, i.e. if any variable's + domain has changed. *) + + val variables : t -> A.Set.t + (** Returns the set of atomic variables that are currently being tracked. *) + + val parents : t -> C.Set.t A.Map.t + (** Returns a map from atomic variables to all the composite variables that + contain them and are currently being tracked. *) + + module Ephemeral : EphemeralDomainMap + with type key = X.r + and type domain = D.t - let pp ppf { active; _ } = - Fmt.( - braces @@ hvbox @@ - iter ~sep:semi CS.iter @@ - using (fun { value; _ } -> value) @@ - box ~indent:2 @@ braces @@ - Constraint.pp - ) ppf active + val edit : events:(A.t, C.t, W.t) events -> t -> Ephemeral.t + (** [edit ~events t] returns an ephemeral copy of the domains for edition. + + The [events] argument is used to notify the caller about domain changes + and watches being triggered. + + {b Note}: Any domain that has changed or watches that have been + triggered through the persistent API (e.g. due to substitutions) are + immediately notified through the appropriare [events] callback. *) + + val snapshot : Ephemeral.t -> t + (** Converts back an ephemeral domain into a persistent one. *) + end += +struct + module DMA = DomainMap(A)(D) + module DMC = DomainMap(C)(D) + + module AW = BinRel(A)(W) + module CW = BinRel(C)(W) + + type t = + { atoms : DMA.t + (* Map from atomic variables to their (non-default) domain. *) + ; atom_watches : AW.t + (* Map (and reverse map) from atomic variables to the watches that must be + triggered when their domain gets updated. *) + ; variables : A.Set.t + (* Set of all atomic variables being tracked. *) + ; composites : DMC.t + (* Map from composite variables to their (non-default) domain. *) + ; composite_watches : CW.t + (* Map (and reverse map) from composite variables to the watches that must + be triggered when their domain gets udpated. *) + ; parents : C.Set.t A.Map.t + (* Reverse map from atomic variables to the composite variables that + contain them. Useful for structural propagation. *) + ; triggers : W.Set.t + (* Watches that have been triggered. They will be immediately notified + when [edit] is called. *) + } + + let pp ppf { atoms ; composites ; _ } = + DMA.pp ppf atoms; + DMC.pp ppf composites let empty = - { args_map = MX.empty - ; leaves_map = MX.empty - ; active = CS.empty - ; pending = CS.empty } - - let cs_add c r cs_map = - MX.update r (function - | Some cs -> Some (CS.add c cs) - | None -> Some (CS.singleton c) - ) cs_map - - let fold_leaves f c acc = - Constraint.fold_args (fun r acc -> - List.fold_left (fun acc r -> f r acc) acc (X.leaves r) - ) c acc - - let add ~ex c t = - let c = explained ~ex c in - (* Note: use [CS.find] here, not [CS.mem], to ensure we use the same - explanation for [c] in the [pending] and [active] sets. *) - if CS.mem c t.active then t else - let active = CS.add c t.active in - let args_map = - Constraint.fold_args (cs_add c) c.value t.args_map - in - let leaves_map = fold_leaves (cs_add c) c.value t.leaves_map in - let pending = CS.add c t.pending in - { active; args_map; leaves_map; pending } - - let cs_remove c r cs_map = - MX.update r (function - | Some cs -> - let cs = CS.remove c cs in - if CS.is_empty cs then None else Some cs - | None -> None - ) cs_map - - let remove c t = - let active = CS.remove c t.active in - let args_map = - Constraint.fold_args (cs_remove c) c.value t.args_map - in - let leaves_map = fold_leaves (cs_remove c) c.value t.leaves_map in - let pending = CS.remove c t.pending in - { active; args_map; leaves_map; pending } + { atoms = DMA.empty + ; atom_watches = AW.empty + ; variables = A.Set.empty + ; composites = DMC.empty + ; composite_watches = CW.empty + ; parents = A.Map.empty + ; triggers = W.Set.empty + } + + type _ Uf.id += Id : t Uf.id + + let filter_ty = D.filter_ty + + exception Inconsistent = D.Inconsistent + + let watch w r t = + let t = { t with triggers = W.Set.add w t.triggers } in + match NF.normal_form r with + | Constant _ -> t + | Atom (a, _) -> + { t with atom_watches = AW.add a w t.atom_watches } + | Composite (c, _) -> + { t with composite_watches = CW.add c w t.composite_watches } + + let unwatch w t = + { atoms = t.atoms + ; atom_watches = AW.remove_range w t.atom_watches + ; variables = t.variables + ; composites = t.composites + ; composite_watches = CW.remove_range w t.composite_watches + ; parents = t.parents + ; triggers = t.triggers } + + let needs_propagation t = + DMA.needs_propagation t.atoms || + DMC.needs_propagation t.composites || + not (W.Set.is_empty t.triggers) + + let variables { variables ; _ } = variables + + let parents { parents ; _ } = parents + + let track c parents = + C.fold (fun a t -> + A.Map.update a (function + | Some cs -> Some (C.Set.add c cs) + | None -> Some (C.Set.singleton c) + ) t + ) c parents + + let untrack c parents = + C.fold (fun a t -> + A.Map.update a (function + | Some cs -> + let cs = C.Set.remove c cs in + if C.Set.is_empty cs then None else Some cs + | None -> None + ) t + ) c parents + + let init r t = + match NF.normal_form r with + | Constant _ -> t + | Atom (a, _) -> + { t with variables = A.Set.add a t.variables } + | Composite (c, _) -> + { t with parents = track c t.parents } + + let default_atom a = D.unknown (A.type_info a) + + let find_or_default_atom a t = + try DMA.find a t.atoms + with Not_found -> default_atom a + + let default_composite c = CD.map_domain default_atom c + + let find_or_default_composite c t = + try DMC.find c t.composites + with Not_found -> default_composite c + + let find_or_default x t = + match x with + | NF.Constant c -> + D.constant c + | NF.Atom (a, o) -> + D.add_offset (find_or_default_atom a t) o + | NF.Composite (c, o) -> + D.add_offset (find_or_default_composite c t) o + + let get r t = find_or_default (NF.normal_form r) t let subst ~ex rr nrr t = - match MX.find rr t.leaves_map with - | cs -> - CS.fold (fun c t -> - let t = remove c t in - let ex = Explanation.union ex c.explanation in - add ~ex (Constraint.subst rr nrr c.value) t - ) cs t - | exception Not_found -> t + let rrd, ws, t = + match NF.normal_form rr with + | Constant _ -> invalid_arg "subst: cannot substitute a constant" + | Atom (a, o) -> + let variables = A.Set.remove a t.variables in + D.add_offset (find_or_default_atom a t) o, + AW.range a t.atom_watches, + { t with + atoms = DMA.remove a t.atoms ; + atom_watches = AW.remove_dom a t.atom_watches ; + variables } + | Composite (c, o) -> + let parents = untrack c t.parents in + D.add_offset (find_or_default_composite c t) o, + CW.range c t.composite_watches, + { t with + composites = DMC.remove c t.composites ; + composite_watches = CW.remove_dom c t.composite_watches ; + parents } + in + (* Add [ex] to justify that it applies to [nrr] *) + let rrd = D.add_explanation ~ex rrd in + let nrr_nf = NF.normal_form nrr in + let nrrd = find_or_default nrr_nf t in + let nnrrd = D.intersect nrrd rrd in + let t = + if D.equal nnrrd rrd then t + else { t with triggers = W.Set.union ws t.triggers } + in + let t = + match nrr_nf with + | Constant _ -> t + | Atom (a, _) -> + let atom_watches = AW.add_many a ws t.atom_watches in + let variables = A.Set.add a t.variables in + { t with atom_watches ; variables } + | Composite (c, _) -> + let composite_watches = CW.add_many c ws t.composite_watches in + let parents = track c t.parents in + { t with composite_watches ; parents } + in + if D.equal nnrrd nrrd then t + else + match nrr_nf with + | Constant _ -> + (* [nrrd] is [D.constant c] which must be a singleton; if we + shrunk it, it can only be empty. *) + assert false + | Atom (a, o) -> + let triggers = W.Set.union (AW.range a t.atom_watches) t.triggers in + let atoms = DMA.add a (D.sub_offset nnrrd o) t.atoms in + { t with atoms ; triggers } + | Composite (c, o) -> + let triggers = + W.Set.union (CW.range c t.composite_watches) t.triggers + in + let composites = DMC.add c (D.sub_offset nnrrd o) t.composites in + { t with composites ; triggers } - let iter_parents f r t = - match MX.find r t.args_map with - | cs -> CS.iter f cs - | exception Not_found -> () + module Ephemeral = struct + type key = X.r + type domain = D.t + + module Entry = struct + type t = + | Constant of NF.constant + | Atom of DMA.Ephemeral.Entry.t * NF.constant + | Composite of DMC.Ephemeral.Entry.t * NF.constant + + let domain = function + | Constant c -> D.constant c + | Atom (a, o) -> + D.add_offset (DMA.Ephemeral.Entry.domain a) o + | Composite (c, o) -> + D.add_offset (DMC.Ephemeral.Entry.domain c) o + + let set_domain e d = + match e with + | Constant _ -> assert false + | Atom (a, o) -> + DMA.Ephemeral.Entry.set_domain a (D.sub_offset d o) + | Composite (c, o) -> + DMC.Ephemeral.Entry.set_domain c (D.sub_offset d o) + end + + type t = + { atoms : DMA.Ephemeral.t + ; atom_watches : AW.t + ; variables : A.Set.t + ; composites : DMC.Ephemeral.t + ; composite_watches : CW.t + ; parents : C.Set.t A.Map.t } + + let entry t r = + match NF.normal_form r with + | NF.Constant c -> + Entry.Constant c + | NF.Atom (a, o) -> + Atom (DMA.Ephemeral.entry t.atoms a, o) + | NF.Composite (c, o) -> + Entry.Composite (DMC.Ephemeral.entry t.composites c, o) + end - let iter_pending f t = - CS.iter f t.pending - - let clear_pending t = - { t with pending = CS.empty } - - let has_pending t = not @@ CS.is_empty t.pending - - let fold_args f c acc = - MX.fold (fun r _ acc -> - f r acc - ) c.args_map acc - - let simplify_pending = - (* Recursion needed because adding new constraints changes the pending set - and they also need to be simplified *) - let rec simplify_aux eqs t to_simplify = - let eqs = ref eqs in - let to_add = ref CS.empty in - let t = - CS.fold (fun ({ value; explanation } as c) t -> - let acts_add_lit_view l = - eqs := (l, explanation) :: !eqs - in - let acts_add_eq u v = - acts_add_lit_view (Uf.LX.mkv_eq u v) - in - let acts_add_constraint c = - let c = { value = c; explanation } in - if not (CS.mem c t.active) then - to_add := CS.add c !to_add - in - let acts = - { acts_add_lit_view - ; acts_add_eq - ; acts_add_constraint } in - if Constraint.simplify value acts then - remove c t - else - t - ) to_simplify t - in - let to_add = !to_add in - if CS.is_empty to_add then - !eqs, t - else - let t = CS.fold (fun c t -> add ~ex:c.explanation c.value t) to_add t in - simplify_aux !eqs t to_add + let edit ~events t = + W.Set.iter events.evt_watch_trigger t.triggers; + + let notify_atom a = + events.evt_atomic_change a; + AW.iter_range a events.evt_watch_trigger t.atom_watches + and notify_composite c = + events.evt_composite_change c; + CW.iter_range c events.evt_watch_trigger t.composite_watches in - fun eqs t -> - if CS.is_empty t.pending then eqs, t else - simplify_aux eqs t t.pending + + { Ephemeral.atoms = + DMA.edit + ~notify:notify_atom ~default:default_atom + t.atoms + ; atom_watches = t.atom_watches + ; variables = t.variables + ; composites = + DMC.edit + ~notify:notify_composite ~default:default_composite + t.composites + ; composite_watches = t.composite_watches + ; parents = t.parents } + + let snapshot t = + { atoms = DMA.snapshot t.Ephemeral.atoms + ; atom_watches = t.Ephemeral.atom_watches + ; variables = t.Ephemeral.variables + ; composites = DMC.snapshot t.Ephemeral.composites + ; composite_watches = t.Ephemeral.composite_watches + ; parents = t.Ephemeral.parents + ; triggers = W.Set.empty } +end + +(** Wrapper around an ephemeral domain map to access domains associated with a + representative computed by the [Uf] module. *) +module UfHandle + (D : Domain) + (DM : EphemeralDomainMap with type key = X.r and type domain = D.t) + : sig + include EphemeralDomainMap with type key = X.r and type domain = D.t + + val wrap : Uf.t -> DM.t -> t + end += +struct + type key = X.r + + type domain = DM.domain + + module Entry = struct + type t = + { repr : X.r + ; handle : DM.Entry.t + ; explanation : Explanation.t } + + let domain { repr ; handle ; explanation = ex } = + if Explanation.is_empty ex then DM.Entry.domain handle + else + D.intersect (D.unknown (X.type_info repr)) @@ + D.add_explanation ~ex (DM.Entry.domain handle) + + let set_domain { handle ; explanation = ex ; _ } d = + DM.Entry.set_domain handle (D.add_explanation ~ex d) + end + + type t = + { uf : Uf.t + ; cache : Entry.t HX.t + ; domains : DM.t } + + let entry t r = + try HX.find t.cache r with Not_found -> + let r, explanation = Uf.find_r t.uf r in + let h = + { Entry.repr = r + ; handle = DM.entry t.domains r + ; explanation } + in + HX.replace t.cache r h; h + + let wrap uf t = + { uf ; cache = HX.create 17 ; domains = t } +end + +module HandleNotations + (D : Domain) + (E : EphemeralDomainMap with type domain = D.t) = +struct + let (!!) = E.Entry.domain + + let update ~ex entry domain = + let current = E.Entry.domain entry in + let domain = D.intersect current (D.add_explanation ~ex domain) in + if not (D.equal domain current) then + E.Entry.set_domain entry domain end diff --git a/src/lib/reasoners/uf.ml b/src/lib/reasoners/uf.ml index d1a4d2e30..5c26bc7e5 100644 --- a/src/lib/reasoners/uf.ml +++ b/src/lib/reasoners/uf.ml @@ -88,7 +88,7 @@ module type GlobalDomain = sig val filter_ty : Ty.t -> bool - val add : X.r -> t -> t + val init : X.r -> t -> t exception Inconsistent of Explanation.t @@ -135,7 +135,7 @@ module GlobalDomains = struct let init r t = let ty = X.type_info r in MapI.map (function B ((module D) as dom, d) as b -> - if D.filter_ty ty then B (dom, D.add r d) else b + if D.filter_ty ty then B (dom, D.init r d) else b ) t let add (type a) ((module D) as dom : a global_domain) v t = diff --git a/src/lib/reasoners/uf.mli b/src/lib/reasoners/uf.mli index 8ffd416fb..1aca26260 100644 --- a/src/lib/reasoners/uf.mli +++ b/src/lib/reasoners/uf.mli @@ -68,8 +68,8 @@ module type GlobalDomain = sig of representatives for which [filter_ty (type_info r)] holds will be propagated to this module. *) - val add : r -> t -> t - (** [add r t] is called when the representative [r] is added to the + val init : r -> t -> t + (** [init r t] is called when the representative [r] is added to the union-find, if it has a type that matches [filter_ty]. {b Note}: unlike [Relation.add], this function is called even for