diff --git a/src/lib/reasoners/bitlist.ml b/src/lib/reasoners/bitlist.ml index df93c9b29..65946b564 100644 --- a/src/lib/reasoners/bitlist.ml +++ b/src/lib/reasoners/bitlist.ml @@ -51,7 +51,7 @@ let ones b = { b with bits_clr = Z.zero } let zeroes b = { b with bits_set = Z.zero } -let add_explanation b ex = { b with ex = Ex.union b.ex ex } +let add_explanation ~ex b = { b with ex = Ex.union b.ex ex } let pp ppf { width; bits_set; bits_clr; ex } = for i = width - 1 downto 0 do @@ -80,7 +80,7 @@ let value b = b.bits_set let is_fully_known b = Z.(equal (shift_right (bits_known b + ~$1) b.width) ~$1) -let intersect b1 b2 ex = +let intersect ~ex b1 b2 = let width = b1.width in let bits_set = Z.logor b1.bits_set b2.bits_set in let bits_clr = Z.logor b1.bits_clr b2.bits_clr in diff --git a/src/lib/reasoners/bitlist.mli b/src/lib/reasoners/bitlist.mli index 4ca65c8e0..b556e6fb4 100644 --- a/src/lib/reasoners/bitlist.mli +++ b/src/lib/reasoners/bitlist.mli @@ -67,8 +67,8 @@ val zeroes : t -> t (** [zeroes b] returns a bitlist where the one bits in [b] are replaced with unknown bits. *) -val add_explanation : t -> Explanation.t -> t -(** [add_explanation b ex] adds the explanation [ex] to the bitlist [b]. The +val add_explanation : ex:Explanation.t -> t -> t +(** [add_explanation ~ex b] adds the explanation [ex] to the bitlist [b]. The returned bitlist has both the explanation of [b] and [ex] as explanation. *) val bits_known : t -> Z.t @@ -87,8 +87,8 @@ val value : t -> Z.t [b] is not fully known, then only the known bits (those that are set in [bits_known b]) are meaningful; unknown bits are set to [0]. *) -val intersect : t -> t -> Explanation.t -> t -(** [intersect b1 b2 ex] returns a new bitlist [b] that subsumes both [b1] and +val intersect : ex:Explanation.t -> t -> t -> t +(** [intersect ~ex b1 b2] returns a new bitlist [b] that subsumes both [b1] and [b2]. The explanation [ex] justifies that the two bitlists can be merged. Raises [Inconsistent] if [b1] and [b2] are not compatible (i.e. there are diff --git a/src/lib/reasoners/bitv.ml b/src/lib/reasoners/bitv.ml index 9e370f08c..e7f190e93 100644 --- a/src/lib/reasoners/bitv.ml +++ b/src/lib/reasoners/bitv.ml @@ -285,6 +285,8 @@ let hash_abstract hash = let negate_abstract xs = List.map negate_simple_term xs +let lognot = negate_abstract + type solver_simple_term = tvar alpha_term let pp_solver_simple_term = pp_alpha_term pp_tvar diff --git a/src/lib/reasoners/bitv.mli b/src/lib/reasoners/bitv.mli index 98c53000d..2248ae513 100644 --- a/src/lib/reasoners/bitv.mli +++ b/src/lib/reasoners/bitv.mli @@ -54,6 +54,8 @@ type 'a abstract = 'a simple_term list [size - 1 .. size - sz] inclusive. *) val extract : int -> int -> int -> 'a abstract -> 'a abstract +val lognot : 'a abstract -> 'a abstract + (** [to_Z_opt r] evaluates [r] to an integer if possible. *) val to_Z_opt : 'a abstract -> Z.t option diff --git a/src/lib/reasoners/bitv_rel.ml b/src/lib/reasoners/bitv_rel.ml index 044f20e46..698ac1a88 100644 --- a/src/lib/reasoners/bitv_rel.ml +++ b/src/lib/reasoners/bitv_rel.ml @@ -32,6 +32,7 @@ module E = Expr module Ex = Explanation module Sy = Symbols module X = Shostak.Combine +module SX = Shostak.SXH module L = Xliteral let timer = Timers.M_Bitv @@ -75,181 +76,62 @@ let is_bv_ty = function let is_bv_r r = is_bv_ty @@ X.type_info r -module SX = Shostak.SXH -module MX = Shostak.MXH - -module Domains : sig - type t - (** The type of domain maps. A domain map maps each representative (semantic - value, of type [X.r]) to its associated domain. - - The keys of the domain maps are expected to be current *class - representatives*, i.e. normal forms wrt to the `Uf` module, in which - case we say the domain map is *normalized*. Use `subst` to ensure that - domain maps stay normalized. *) - - val pp : t Fmt.t - (** Pretty-printer for domain maps. *) - - val empty : t - (** Returns an empty domain map. *) - - val update : Ex.t -> X.r -> t -> Bitlist.t -> t - (** [update ex r d bl] intersects the domain of [r] with bitlist [bl]. - - The explanation [ex] justifies that [bl] applies to [r]. - - @raise Bitlist.Inconsistent if the new domain is empty. *) - - val get : X.r -> t -> Bitlist.t - (** [get r d] returns the bitlist currently associated with [r] in [d]. +module Domain : Rel_utils.Domain with type t = Bitlist.t = struct + (* Note: these functions are not in [Bitlist] proper in order to avoid a + (direct) dependency from [Bitlist] to the [Shostak] module. *) - @raise Not_found if there is no bitlist associated with [r] in [d]. *) - - val subst : Ex.t -> X.r -> X.r -> t -> t - (** [subst ex p v d] replaces all the instances of [p] with [v] in any domain, - and merges the corresponding bitlists. - - Use this to ensure that the representation is always normalized. - - The explanation [ex] justifies the equality [p = v]. - - @raise Bitlist.Inconsistent if this causes any domain in [d] to become - empty. *) - - val choose_changed : t -> X.r * t - (** [choose_changed d] returns a pair [r, d'] such that: - - - The domain associated with [r] has changed since the last time - [choose_changed] was called, - - [r] has (by definition) not changed in [d'] *) - - val fold : (X.r -> Bitlist.t -> 'a -> 'a) -> t -> 'a -> 'a - (** [fold f d acc] folds [f] over all the domains associated with variables *) -end = struct - type t = - { bitlists : Bitlist.t MX.t - (** Mapping from semantic values to their bitlist domain. - - Note: this mapping only contains domain for leaves (i.e. uninterpreted - terms or AC symbols); domains associated with more complex semantic - values are rebuilt on-the-fly using the structure of said semantic - values. *) - ; changed : SX.t - (** Elements whose domain has changed since last propagation. *) - } - - let pp ppf t = - Fmt.(iter_bindings ~sep:semi MX.iter - (box @@ pair ~sep:(any " ->@ ") X.print Bitlist.pp) - |> braces - ) - ppf t.bitlists - let empty = { bitlists = MX.empty ; changed = SX.empty } - - let update_leaf ex r t bl = - let changed = ref false in - let bitlists = - MX.update r (function - | Some bl' as o -> - let bl'' = Bitlist.intersect bl bl' ex in - (* Keep simpler explanations, and don't loop adding the domain to - the changed set infinitely. *) - if Bitlist.equal bl' bl'' then - o - else ( - if Options.get_debug_bitv () then - Printer.print_dbg - ~module_name:"Bitv_rel" ~function_name:"Domain.update" - "domain shrunk for %a: %a -> %a" - X.print r Bitlist.pp bl' Bitlist.pp bl''; - changed := true; - Some bl'' - ) - | None -> - changed := true; - Some (Bitlist.add_explanation bl ex) - ) t.bitlists - in - let changed = if !changed then SX.add r t.changed else t.changed in - { changed; bitlists } + include Bitlist - let update_signed ex { Bitv.value; negated } t bl = - let bl = if negated then Bitlist.lognot bl else bl in - update_leaf ex value t bl + let fold_signed f { Bitv.value; negated } bl acc = + let bl = if negated then lognot bl else bl in + f value bl acc - let update ex r t bl = - fst @@ List.fold_left (fun (t, bl) { Bitv.bv; sz } -> + let fold_leaves f r bl acc = + fst @@ List.fold_left (fun (acc, bl) { Bitv.bv; sz } -> (* Extract the bitlist associated with the current component *) - let mid = Bitlist.width bl - sz in + let mid = width bl - sz in let bl_tail = - if mid = 0 then Bitlist.empty else - Bitlist.extract bl 0 (mid - 1) + if mid = 0 then empty else + extract bl 0 (mid - 1) in - let bl = Bitlist.extract bl mid (Bitlist.width bl - 1) in + let bl = extract bl mid (width bl - 1) in match bv with | Bitv.Cte z -> (* Nothing to update, but still check for consistency! *) - ignore @@ Bitlist.intersect bl (Bitlist.exact sz z Ex.empty) ex; - t, bl_tail - | Other r -> update_signed ex r t bl, bl_tail + ignore @@ intersect ~ex:Ex.empty bl (exact sz z Ex.empty); + acc, bl_tail + | Other r -> fold_signed f r bl acc, bl_tail | Ext (r, r_size, i, j) -> (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) - assert (i + r_size - j - 1 + Bitlist.width bl = r_size); + assert (i + r_size - j - 1 + width bl = r_size); let hi = Bitlist.unknown (r_size - j - 1) Ex.empty in let lo = Bitlist.unknown i Ex.empty in - update_signed ex r t Bitlist.(hi @ bl @ lo), bl_tail - ) (t, bl) (Shostak.Bitv.embed r) + fold_signed f r (hi @ bl @ lo) acc, bl_tail + ) (acc, bl) (Shostak.Bitv.embed r) - let get_leaf r t = - try MX.find r t.bitlists with - | Not_found -> - match X.type_info r with - | Tbitv n -> Bitlist.unknown n Explanation.empty - | _ -> assert false + let map_signed f { Bitv.value; negated } t = + let bl = f value t in + if negated then lognot bl else bl - let get_signed { Bitv.value; negated } t = - let bl = get_leaf value t in - if negated then Bitlist.lognot bl else bl - - let get r t = + let map_leaves f r acc = List.fold_left (fun bl { Bitv.bv; sz } -> - Bitlist.concat bl @@ + concat bl @@ match bv with - | Bitv.Cte z -> Bitlist.exact sz z Ex.empty - | Other r -> get_signed r t - | Ext (r, _r_size, i, j) -> Bitlist.extract (get_signed r t) i j - ) Bitlist.empty (Shostak.Bitv.embed r) - - let subst ex rr nrr t = - match MX.find rr t.bitlists with - | bl -> - (* The substitution code for constraints requires that we properly update - the [changed] field here: if the domain of [rr] has changed, - constraints that applied to [rr] will apply to [nrr] after - substitution and must be propagated again. *) - let changed = - if SX.mem rr t.changed then - SX.add nrr t.changed - else - t.changed - in - let t = - { changed = SX.remove rr changed - ; bitlists = MX.remove rr t.bitlists - } - in - update ex nrr t bl - | exception Not_found -> t - - let choose_changed t = - let r = SX.choose t.changed in - r, { t with changed = SX.remove r t.changed } - - let fold f t = MX.fold f t.bitlists + | Bitv.Cte z -> exact sz z Ex.empty + | Other r -> map_signed f r acc + | Ext (r, _r_size, i, j) -> extract (map_signed f r acc) i j + ) empty (Shostak.Bitv.embed r) + + let unknown = function + | Ty.Tbitv n -> unknown n Ex.empty + | _ -> + (* Only bit-vector values can have bitlist domains. *) + invalid_arg "unknown" end +module Domains = Rel_utils.Domains_make(Domain) + module Constraint : sig include Rel_utils.Constraint @@ -262,9 +144,6 @@ module Constraint : sig val bvxor : X.r -> X.r -> X.r -> t (** [bvxor x y z] is the constraint [x ^ y ^ z = 0] *) - val bvnot : X.r -> X.r -> t - (** [bvnot x y] is the constraint [x = not y] *) - val propagate : ex:Ex.t -> t -> Domains.t -> Domains.t (** [propagate ~ex t dom] propagates the constraint [t] in domain [dom]. @@ -278,8 +157,6 @@ end = struct (** [Bor (x, y, z)] represents [x = y | z] *) | Bxor of SX.t (** [Bxor {x1, ..., xn}] represents [x1 ^ ... ^ xn = 0] *) - | Bnot of X.r * X.r - (** [Bnot (x, y)] represents [x = not y] *) let normalize_repr = function | Band (x, y, z) when X.hash_cmp y z > 0 -> Band (x, z, y) @@ -292,19 +169,15 @@ end = struct | Bor (x1, y1, z1), Bor (x2, y2, z2) -> X.equal x1 x2 && X.equal y1 y2 && X.equal z1 z2 | Bxor xs1, Bxor xs2 -> SX.equal xs1 xs2 - | Bnot (x1, y1), Bnot (x2, y2) -> - (X.equal x1 x2 && X.equal y1 y2) | Band _, _ | Bor _, _ - | Bxor _, _ - | Bnot _, _ -> false + | Bxor _, _ -> false let hash_repr = function | Band (x, y, z) -> Hashtbl.hash (0, X.hash x, X.hash y, X.hash z) | Bor (x, y, z) -> Hashtbl.hash (1, X.hash x, X.hash y, X.hash z) | Bxor xs -> Hashtbl.hash (2, SX.fold (fun r acc -> X.hash r :: acc) xs []) - | Bnot (x, y) -> Hashtbl.hash (2, X.hash x, X.hash y) type t = { repr : repr ; mutable tag : int } @@ -336,8 +209,6 @@ end = struct | Bxor xs -> Fmt.(iter ~sep:(any " ^@ ") SX.iter X.print |> box) ppf xs; Fmt.pf ppf " = 0" - | Bnot (x, y) -> - Fmt.pf ppf "%a = ~%a" X.print x X.print y (* TODO: for bitwise constraints (eg Band, Bor, Bxor) on initialisation and also after substitution @@ -362,10 +233,6 @@ end = struct if SX.mem r xs then SX.remove r xs else SX.add r xs ) xs SX.empty ) - | Bnot (x, y) -> - let x = X.subst rr nrr x - and y = X.subst rr nrr y in - Bnot (x, y) let pp ppf { repr; _ } = pp_repr ppf repr @@ -382,10 +249,6 @@ end = struct let acc = f z acc in acc | Bxor xs -> SX.fold f xs acc - | Bnot (x, y) -> - let acc = f x acc in - let acc = f y acc in - acc let propagate ~ex { repr; _ } dom = Steps.incr CP; @@ -395,18 +258,22 @@ end = struct and dy = Domains.get y dom and dz = Domains.get z dom in - let dom = Domains.update ex x dom @@ Bitlist.logand dy dz in + let dom = + Domains.update x Bitlist.(add_explanation ~ex (logand dy dz)) dom + in (* Reverse propagation for y: if [x = y & z] then: - Any [1] in [x] must be a [1] in [y] - Any [0] in [x] that is also a [1] in [z] must be a [0] in [y] *) let dom = - Domains.update ex y dom @@ - Bitlist.(intersect (ones dx) (logor (zeroes dx) (lognot (ones dz))) ex) + Domains.update y Bitlist.( + intersect ~ex (ones dx) (logor (zeroes dx) (lognot (ones dz))) + ) dom in let dom = - Domains.update ex z dom @@ - Bitlist.(intersect (ones dx) (logor (zeroes dx) (lognot (ones dy))) ex) + Domains.update z Bitlist.( + intersect ~ex (ones dx) (logor (zeroes dx) (lognot (ones dy))) + ) dom in dom | Bor (x, y, z) -> @@ -414,20 +281,22 @@ end = struct and dy = Domains.get y dom and dz = Domains.get z dom in - let dom = Domains.update ex x dom @@ Bitlist.logor dy dz in + let dom = + Domains.update x Bitlist.(add_explanation ~ex (logor dy dz)) dom + in (* Reverse propagation for y: if [x = y | z] then: - Any [0] in [x] must be a [0] in [y] - Any [1] in [x] that is also a [0] in [z] must be a [1] in [y] *) let dom = - Domains.update ex y dom @@ Bitlist.( - intersect (zeroes dx) (logand (ones dx) (lognot (zeroes dz))) ex - ) + Domains.update y Bitlist.( + intersect ~ex (zeroes dx) (logand (ones dx) (lognot (zeroes dz))) + ) dom in let dom = - Domains.update ex z dom @@ Bitlist.( - intersect (zeroes dx) (logand (ones dx) (lognot (zeroes dy))) ex - ) + Domains.update z Bitlist.( + intersect ~ex (zeroes dx) (logand (ones dx) (lognot (zeroes dy))) + ) dom in dom | Bxor xs -> @@ -441,13 +310,8 @@ end = struct Bitlist.logxor (Domains.get y dom) acc ) xs (Bitlist.exact (Bitlist.width dx) Z.zero Ex.empty) in - Domains.update ex x dom dx' + Domains.update x (Bitlist.add_explanation ~ex dx') dom ) xs dom - | Bnot (x, y) -> - let dx = Domains.get x dom and dy = Domains.get y dom in - let dom = Domains.update ex x dom @@ Bitlist.lognot dy in - let dom = Domains.update ex y dom @@ Bitlist.lognot dx in - dom let bvand x y z = hcons @@ Band (x, y, z) let bvor x y z = hcons @@ Bor (x, y, z) @@ -456,18 +320,12 @@ end = struct let xs = if SX.mem y xs then SX.remove y xs else SX.add y xs in let xs = if SX.mem z xs then SX.remove z xs else SX.add z xs in hcons @@ Bxor xs - let bvnot x y = hcons @@ Bnot (x, y) end module Constraints = Rel_utils.Constraints_make(Constraint) let extract_constraints bcs uf r t = match E.term_view t with - (* BVnot is already internalized in the Shostak but we want to know about it - without needing a round-trip through Uf *) - | { f = Op BVnot; xs = [ x ] ; _ } -> - let rx, exx = Uf.find uf x in - Constraints.add ~ex:exx (Constraint.bvnot r rx) bcs | { f = Op BVand; xs = [ x; y ]; _ } -> let rx, exx = Uf.find uf x and ry, exy = Uf.find uf y in @@ -545,15 +403,15 @@ let propagate = | exception Not_found -> match Domains.choose_changed dom with | r, dom -> - propagate (SX.add r changed) (Constraints.notify_leaf r bcs) dom + propagate (SX.add r changed) (Constraints.notify r bcs) dom | exception Not_found -> changed, bcs, dom in - fun bcs dom -> + fun eqs bcs dom -> let changed, bcs, dom = propagate SX.empty bcs dom in SX.fold (fun r acc -> add_eqs acc (Shostak.Bitv.embed r) (Domains.get r dom) - ) changed [], bcs, dom + ) changed eqs, bcs, dom type t = { delayed : Rel_utils.Delayed.t @@ -571,8 +429,8 @@ let assume env uf la = let delayed, result = Rel_utils.Delayed.assume env.delayed uf la in let (domain, constraints, eqs, size_splits) = try - let ((constraints, domain), size_splits) = - List.fold_left (fun ((bcs, dom), ss) (a, _root, ex, orig) -> + let ((constraints, domain), eqs, size_splits) = + List.fold_left (fun ((bcs, dom), eqs, ss) (a, _root, ex, orig) -> let ss = match orig with | Th_util.CS (Th_bitv, n) -> Q.(ss * n) @@ -585,31 +443,25 @@ let assume env uf la = in match a, orig with | L.Eq (rr, nrr), Subst when is_bv_r rr -> - let dom = Domains.subst ex rr nrr dom in + let dom = Domains.subst ~ex rr nrr dom in let bcs = Constraints.subst ~ex rr nrr bcs in - ((bcs, dom), ss) + ((bcs, dom), eqs, ss) | L.Distinct (false, [rr; nrr]), _ when is_1bit rr -> (* We don't (yet) support [distinct] in general, but we must support it for case splits to avoid looping. We are a bit more general and support it for 1-bit vectors, for - which `distinct` can be expressed using `bvnot`. - - Note that we are not guaranteed that the arguments are already - in normal form! *) - let rr, exrr = Uf.find_r uf rr in - let nrr, exnrr = Uf.find_r uf nrr in - let ex = Ex.union ex (Ex.union exrr exnrr) in - let bcs = - Constraints.add ~ex (Constraint.bvnot rr nrr) bcs + which `distinct` can be expressed using `bvnot`. *) + let not_nrr = + Shostak.Bitv.is_mine (Bitv.lognot (Shostak.Bitv.embed nrr)) in - ((bcs, dom), ss) - | _ -> ((bcs, dom), ss) + ((bcs, dom), (Uf.LX.mkv_eq rr not_nrr, ex) :: eqs, ss) + | _ -> ((bcs, dom), eqs, ss) ) - ((env.constraints, env.domain), env.size_splits) + ((env.constraints, env.domain), [], env.size_splits) la in - let eqs, constraints, domain = propagate constraints domain in + let eqs, constraints, domain = propagate eqs constraints domain in if Options.get_debug_bitv () && not (Lists.is_empty eqs) then ( Printer.print_dbg ~module_name:"Bitv_rel" ~function_name:"assume" @@ -667,7 +519,7 @@ let case_split env _uf ~for_model = with | Some (nunk, xs) -> nunk, xs | None -> - match Domains.fold f_acc env.domain None with + match Domains.fold_leaves f_acc env.domain None with | Some (nunk, xs) -> nunk, xs | None -> 0, SX.empty in @@ -695,15 +547,13 @@ let add env uf r t = let delayed, eqs = Rel_utils.Delayed.add env.delayed uf r t in let env, eqs = match X.type_info r with - | Tbitv n -> ( + | Tbitv _ -> ( try - let dr = Bitlist.unknown n Ex.empty in - let dom = Domains.update Ex.empty r env.domain dr in + let dom = Domains.add r env.domain in let bcs = extract_constraints env.constraints uf r t in - let eqs', bcs, dom = propagate bcs dom in - { env with constraints = bcs ; domain = dom }, - List.rev_append eqs' eqs - with Bitlist.Inconsistent ex -> + let eqs, bcs, dom = propagate eqs bcs dom in + { env with constraints = bcs ; domain = dom }, eqs + with Domains.Inconsistent ex -> raise @@ Ex.Inconsistent (ex, Uf.cl_extract uf) ) | _ -> env, eqs diff --git a/src/lib/reasoners/rel_utils.ml b/src/lib/reasoners/rel_utils.ml index 45e55d555..c90158be2 100644 --- a/src/lib/reasoners/rel_utils.ml +++ b/src/lib/reasoners/rel_utils.ml @@ -15,7 +15,8 @@ (**************************************************************************) module X = Shostak.Combine -module MXH = Shostak.MXH +module MX = Shostak.MXH +module SX = Shostak.SXH module L = Xliteral module LR = Uf.LX module SR = Set.Make( @@ -121,10 +122,10 @@ end = struct type t = { dispatch : Symbols.operator -> delayed_fn option ; - used_by : Expr.Set.t OMap.t MXH.t ; + used_by : Expr.Set.t OMap.t MX.t ; } - let create dispatch = { dispatch; used_by = MXH.empty } + let create dispatch = { dispatch; used_by = MX.empty } let add ({ dispatch; used_by } as env) uf r t = (* Note: we dispatch on [Op] symbols, but it could make sense to dispatch on @@ -144,7 +145,7 @@ end = struct | None -> let used_by = List.fold_left (fun used_by x -> - MXH.update (Uf.make uf x) (fun sm -> + MX.update (Uf.make uf x) (fun sm -> let sm = Option.value ~default:OMap.empty sm in Option.some @@ OMap.update f (fun se -> @@ -155,7 +156,7 @@ end = struct | _ -> env, [] let update { dispatch; used_by } uf r1 eqs = - match MXH.find r1 used_by with + match MX.find r1 used_by with | exception Not_found -> eqs | sm -> OMap.fold (fun sy se eqs -> @@ -193,6 +194,307 @@ end = struct env, { Sig_rel.assume = assume_nontrivial_eqs eqs la; remove = [] } end +module type Domain = sig + type t + (** The type of domains for a single value. + + This is an abstract type that is instanciated by the theory. Note that + it is expected that this type can carry explanations. *) + + val equal : t -> t -> bool + (** [equal d1 d2] returns [true] if the domains [d1] and [d2] are + identical. Explanations should not be taken into consideration, i.e. + two domains with different explanations but identical semantics content + should compare equal. *) + + val pp : t Fmt.t + (** Pretty-printer for domains. *) + + exception Inconsistent of Explanation.t + (** Exception raised by [intersect] when an inconsistency is detected. *) + + val unknown : Ty.t -> t + (** [unknown ty] returns a full domain for values of type [t]. *) + + val add_explanation : ex:Explanation.t -> t -> t + (** [add_explanation ~ex d] adds the justification [ex] to the domain d. The + returned domain is identical to the domain of [d], only the + justifications are changed. *) + + val intersect : ex:Explanation.t -> t -> t -> t + (** [intersect ~ex d1 d2] returns a new domain [d] that subsumes both [d1] + and [d2]. The explanation [ex] justifies that the two domains can be + merged. + + @raise Inconsistent if [d1] and [d2] are not compatible (the + intersection would be empty). The justification always includes the + justification [ex] used for the intersection. *) + + + val fold_leaves : (X.r -> t -> 'a -> 'a) -> X.r -> t -> 'a -> 'a + (** [fold_leaves f r t acc] folds [f] over the leaves of [r], deconstructing + the domain [t] according to the structure of [r]. + + It is assumed that [t] already contains any justification required for + it to apply to [r]. + + @raise Inconsistent if [r] cannot possibly be in the domain of [t]. *) + + val map_leaves : (X.r -> 'a -> t) -> X.r -> 'a -> t + (** [map_leaves f r acc] is the "inverse" of [fold_leaves] in the sense that + it rebuilds a domain for [r] by using [f] to access the domain for each + of [r]'s leaves. *) +end + +module type Domains = sig + type t + (** The type of domain maps. A domain map maps each representative (semantic + value, of type [X.r]) to its associated domain.*) + + val pp : t Fmt.t + (** Pretty-printer for domain maps. *) + + val empty : t + (** The empty domain map. *) + + type elt + (** The type of domains contained in the map. Each domain of type [elt] + applies to a single semantic value. *) + + exception Inconsistent of Explanation.t + (** Exception raised by [update], [subst] and [structural_propagation] when + an inconsistency is detected. *) + + val add : X.r -> t -> t + (** [add r t] adds a domain for [r] in the domain map. If [r] does not + already have an associated domain, a fresh domain will be created for + [r]. *) + + val get : X.r -> t -> elt + (** [get r t] returns the domain currently associated with [r] in [t]. *) + + val update : X.r -> elt -> t -> t + (** [update r d t] intersects the domain of [r] in [t] with the domain [d]. + + {b Soundness}: The domain [d] must already include the justification + that it applies to [r]. + + @raise Inconsistent if this causes the domain associated with [r] to + become empty. *) + + val fold_leaves : (X.r -> elt -> 'a -> 'a) -> t -> 'a -> 'a + (** [fold f t acc] folds [f] over all the domains in [t] that are associated + with leaves. *) + + val subst : ex:Explanation.t -> X.r -> X.r -> t -> t + (** [subst ~ex p v d] replaces all the instances of [p] with [v] in all + domains, merging the corresponding domains as appropriate. + + The explanation [ex] justifies the equality [p = v]. + + @raise Inconsistent if this causes any domain in [d] to become empty. *) + + val choose_changed : t -> X.r * t + (** [choose_changed d] returns a pair [r, d'] such that: + + - The domain associated with [r] has changed since the last time + [choose_changed] was called. + - [r] has (by definition) not changed in [d'] + + Moreover, prior to returning [r], structural propagation is + automatically performed. + + More precisely, if [r] is a leaf, the domain of [r] is propagated to any + semantic value that contains [r] as a leaf according to the structure of + that semantic value (using [Domain.map_leaves]); if [r] is not a leaf, + its domain is propagated to any of the leaves it contains. + + We only perform *forward* structural propagation: if structural + propagation causes a domain of a leaf or parent to be changed, then we + will perform structural propagation for that leaf or parent once it + itself is selected by [choose_changed]. + + @raise Inconsistent if an inconsistency if detected during structural + propagation. *) +end + +module Domains_make(Domain : Domain) : Domains with type elt = Domain.t = +struct + type elt = Domain.t + + exception Inconsistent = Domain.Inconsistent + + type t = { + domains : Domain.t MX.t ; + (** Map from tracked representatives to their domain *) + + changed : SX.t ; + (** Representatives whose domain has changed since the last flush *) + + leaves_map : SX.t MX.t ; + (** Map from leaves to the *tracked* representatives that contains them *) + } + + let pp ppf t = + Fmt.(iter_bindings ~sep:semi MX.iter + (box @@ pair ~sep:(any " ->@ ") X.print Domain.pp) + |> braces + ) + ppf t.domains + + let empty = + { domains = MX.empty ; changed = SX.empty ; leaves_map = MX.empty } + + let r_add r leaves_map = + List.fold_left (fun leaves_map leaf -> + MX.update leaf (function + | Some parents -> Some (SX.add r parents) + | None -> Some (SX.singleton r) + ) leaves_map + ) leaves_map (X.leaves r) + + let create_domain r = + Domain.map_leaves (fun r () -> + Domain.unknown (X.type_info r) + ) r () + + let add r t = + match MX.find r t.domains with + | _ -> t + | exception Not_found -> + (* Note: we do not need to mark [r] as needing propagation, because no + constraints applied to it yet. Any constraint that apply to [r] will + already be marked as pending due to being newly added. *) + let d = create_domain r in + let domains = MX.add r d t.domains in + let leaves_map = r_add r t.leaves_map in + { t with domains; leaves_map } + + let r_remove r leaves_map = + List.fold_left (fun leaves_map leaf -> + MX.update leaf (function + | Some parents -> + let parents = SX.remove r parents in + if SX.is_empty parents then None else Some parents + | None -> None + ) leaves_map + ) leaves_map (X.leaves r) + + let remove r t = + let changed = SX.remove r t.changed in + let domains = MX.remove r t.domains in + let leaves_map = r_remove r t.leaves_map in + { changed; domains; leaves_map } + + let get r t = + (* We need to catch [Not_found] because of fresh terms that can be added + by the solver and for which we don't call [add]. Note that in this + case, only structural constraints can apply to [r]. *) + try MX.find r t.domains with Not_found -> create_domain r + + let update r d t = + match MX.find r t.domains with + | od -> + (* Both domains are already valid for [r], we can intersect them + without additional justifications. *) + let d = Domain.intersect ~ex:Explanation.empty od d in + if Domain.equal od d then + t + else + let domains = MX.add r d t.domains in + let changed = SX.add r t.changed in + { t with domains; changed } + | exception Not_found -> + (* We need to catch [Not_found] because of fresh terms that can be added + by the solver and for which we don't call [add]. *) + let d = Domain.intersect ~ex:Explanation.empty d (create_domain r) in + let domains = MX.add r d t.domains in + let changed = SX.add r t.changed in + let leaves_map = r_add r t.leaves_map in + { domains; changed; leaves_map } + + let fold_leaves f t acc = + MX.fold (fun r _ acc -> + f r (get r t) acc + ) t.leaves_map acc + + let subst ~ex rr nrr t = + match MX.find rr t.leaves_map with + | parents -> + SX.fold (fun r t -> + let d = + try MX.find r t.domains + with Not_found -> + (* [r] was in the [leaves_map] to it must have a domain *) + assert false + in + let changed = SX.mem r t.changed in + let t = remove r t in + let nr = X.subst rr nrr r in + match MX.find nr t.domains with + | nd -> + (* If there is an existing domain for [nr], there might be + constraints applying to [nr] prior to the substitution, and the + constraints that used to apply to [r] will also apply to [nr] + after the substitution. + + We need to notify changed to either of these constraints, so we + must notify if the domain is different from *either* the old + domain of [r] or the old domain of [nr]. *) + let nnd = Domain.intersect ~ex d nd in + let nr_changed = not (Domain.equal nnd nd) in + let r_changed = not (Domain.equal nnd d) in + let domains = + if nr_changed then MX.add nr nnd t.domains else t.domains + in + let changed = changed || r_changed || nr_changed in + let changed = + if changed then SX.add nr t.changed else t.changed + in + { t with domains; changed } + | exception Not_found -> + (* If there is no existing domain for [nr], there were no + constraints applying to [nr] prior to the substitution. + + The only constraints that need to be notified are those that + were applying to [r], and they only need to be notified if the + new domain is different from the old domain of [r]. *) + let nd = Domain.intersect ~ex d (create_domain nr) in + let r_changed = not (Domain.equal nd d) in + let domains = MX.add nr nd t.domains in + let leaves_map = r_add nr t.leaves_map in + let changed = changed || r_changed in + let changed = + if changed then SX.add nr t.changed else t.changed + in + { domains; changed; leaves_map } + ) parents t + | exception Not_found -> + (* We are not tracking any semantic value that have [r] as a leaf, so we + have nothing to do. *) + t + + let structural_propagation r t = + if X.is_a_leaf r then + match MX.find r t.leaves_map with + | parents -> + SX.fold (fun parent t -> + if X.is_a_leaf parent then ( + assert (X.equal r parent); + t + ) else + update parent (Domain.map_leaves get parent t) t + ) parents t + | exception Not_found -> t + else + Domain.fold_leaves update r (get r t) t + + let choose_changed t = + let r = SX.choose t.changed in + let t = { t with changed = SX.remove r t.changed } in + r, structural_propagation r t +end + module type Constraint = sig type t (** The type of constraints. @@ -288,8 +590,6 @@ module Constraints_make(Constraint : Constraint) : sig @raise Not_found if there are no pending constraints. *) end = struct - module MX = Shostak.MXH - module CS = Set.Make(struct type t = Constraint.t explained