diff --git a/src/lib/reasoners/bitv_rel.ml b/src/lib/reasoners/bitv_rel.ml index 15761cc24..0bfc97369 100644 --- a/src/lib/reasoners/bitv_rel.ml +++ b/src/lib/reasoners/bitv_rel.ml @@ -33,7 +33,6 @@ module Ex = Explanation module Sy = Symbols module X = Shostak.Combine module L = Xliteral -module Congruence = Rel_utils.Congruence (* Currently we only compute, but in the future we may want to perform the same simplifications as in [Bitv.make]. We currently don't, because we don't @@ -244,59 +243,29 @@ end = struct end module Constraint : sig - type repr = - | Band of X.r * X.r * X.r - (** [Band (x, y, z)] represents [x = y & z] *) - | Bor of X.r * X.r * X.r - (** [Bor (x, y, z)] represents [x = y | z] *) - | Bxor of SX.t - (** [Bxor {x1, ..., xn}] represents [x1 ^ ... ^ xn = 0] *) - | Bnot of X.r * X.r - (** [Bnot (x, y)] represents [x = not y] *) - - type tagged_repr - - val hcons : repr -> tagged_repr - (** Internalize the constraint representation. - - This uses hash-consing and some simple normalization to de-duplicate - constraints. *) + include Rel_utils.Constraint with type domain = Domains.t - val tag : tagged_repr -> int - (** Returns the unique tag associated with the tagged repr. *) + val bvand : ex:Ex.t -> X.r -> X.r -> X.r -> t + (** [bvand ~ex x y z] is the constraint [x = y & z] *) - type t = { repr : tagged_repr ; ex : Ex.t } - (** The type of bit-vector constraints. - - Bit-vector constraints contains semantic values / term representatives of - type [X.r]. We maintain the invariant that the semantic values used inside - the constraints are *class representatives* i.e. normal forms wrt the `Uf` - module, i.e. constraints have a normalized representation. Use `subst` to - ensure normalization. *) + val bvor : ex:Ex.t -> X.r -> X.r -> X.r -> t + (** [bvor ~ex x y z] is the constraint [x = y | z] *) - val pp : t Fmt.t - (** Pretty-printer for constraints. *) + val bvxor : ex:Ex.t -> X.r -> X.r -> X.r -> t + (** [bvxor ~ex x y z] is the constraint [x ^ y ^ z = 0] *) - val subst : Ex.t -> X.r -> X.r -> t -> t - (** [subst ex p v cs] replaces all the instaces of [p] with [v] in the - constraint. - - Use this to ensure that the representation is always normalized. - - The explanation [ex] justifies the equality [p = v]. *) - - val fold_deps : (X.r -> 'a -> 'a) -> t -> 'a -> 'a - (** [fold_deps f c acc] accumulates [f] over the arguments of [c]. *) - - val propagate : t -> Domains.t -> Domains.t - (** [propagate c dom] propagates the constraints [c] in [d] and returns the - new domains. *) + val bvnot : ex:Ex.t -> X.r -> X.r -> t + (** [bvnot ~ex x y] is the constraint [x = not y] *) end = struct type repr = | Band of X.r * X.r * X.r + (** [Band (x, y, z)] represents [x = y & z] *) | Bor of X.r * X.r * X.r + (** [Bor (x, y, z)] represents [x = y | z] *) | Bxor of SX.t + (** [Bxor {x1, ..., xn}] represents [x1 ^ ... ^ xn = 0] *) | Bnot of X.r * X.r + (** [Bnot (x, y)] represents [x = not y] *) let normalize_repr = function | Band (x, y, z) when X.hash_cmp y z > 0 -> Band (x, z, y) @@ -345,8 +314,6 @@ end = struct ); tagged - let tag { tag; _ } = tag - let pp_repr ppf = function | Band (x, y, z) -> Fmt.pf ppf "%a & %a = %a" X.print y X.print z X.print x @@ -365,25 +332,25 @@ end = struct let subst_repr rr nrr = function | Band (x, y, z) -> - let x = if X.equal x rr then nrr else x - and y = if X.equal y rr then nrr else y - and z = if X.equal z rr then nrr else z in + let x = X.subst rr nrr x + and y = X.subst rr nrr y + and z = X.subst rr nrr z in Band (x, y, z) | Bor (x, y, z) -> - let x = if X.equal x rr then nrr else x - and y = if X.equal y rr then nrr else y - and z = if X.equal z rr then nrr else z in + let x = X.subst rr nrr x + and y = X.subst rr nrr y + and z = X.subst rr nrr z in Bor (x, y, z) | Bxor xs -> Bxor ( SX.fold (fun r xs -> - let r = if X.equal r rr then nrr else r in + let r = X.subst rr nrr r in if SX.mem r xs then SX.remove r xs else SX.add r xs ) xs SX.empty ) | Bnot (x, y) -> - let x = if X.equal x rr then nrr else x - and y = if X.equal y rr then nrr else y in + let x = X.subst rr nrr x + and y = X.subst rr nrr y in Bnot (x, y) (* The explanation justifies why the constraint holds. *) @@ -391,6 +358,9 @@ end = struct let pp ppf { repr; _ } = pp_repr ppf repr.repr + let compare { repr = r1; _ } { repr = r2; _ } = + Int.compare r1.tag r2.tag + let subst ex rr nrr c = { repr = hcons @@ subst_repr rr nrr c.repr.repr ; ex = Ex.union ex c.ex } @@ -407,6 +377,13 @@ end = struct let acc = f y acc in acc + let fold_leaves f c acc = + fold_deps (fun r acc -> + List.fold_left (fun acc r -> f r acc) acc (X.leaves r) + ) c acc + + type domain = Domains.t + let propagate { repr; ex } dom = Steps.incr CP; match repr.repr with @@ -468,177 +445,41 @@ end = struct let dom = Domains.update ex x dom @@ Bitlist.lognot dy in let dom = Domains.update ex y dom @@ Bitlist.lognot dx in dom -end - -module Constraints : sig - type t - (** The type of constraint sets. A constraint sets records a set of - constraints that applies to semantic values, and remembers which - constraints are associated with each semantic values. - - It is used to only propagate constraints involving semantic values whose - associated domain has changed. - - The constraint sets are expected to keep track of *class representatives*, - i.e. normal forms wrt the `Uf` module, in which case we say the - constraint set is *normalized*. Use `subst` to ensure normalization. *) - - val pp : t Fmt.t - (** Pretty-printer for constraint sets. *) - - val empty : t - (** Returns an empty constraint set. *) - - val subst : Ex.t -> X.r -> X.r -> t -> t - (** [subst ex p v cs] replaces all the instances of [p] with [v] in the - constraints. - - Use this to ensure that the representation is always normalized. - - The explanation [ex] justifies the equality [p = v]. *) - val add : t -> Constraint.t -> t - (** [add c cs] adds the constraint [c] to [cs]. *) + let make ?(ex = Ex.empty) repr = { repr = hcons repr ; ex } - val fold_fresh : (Constraint.t -> 'a -> 'a) -> t -> 'a -> t * 'a - (** [fold_fresh f cs acc] folds [f] over the fresh constraints in [cs]. - - Fresh constraints are constraints that were never propagated yet. *) - - val fold_r : (X.r -> 'a -> 'a) -> t -> 'a -> 'a - (** [fold_r f cs acc] folds [f] over any representative [r] that is currently - associated with a constraint (i.e. at least one constraint currently - applies to [r]). *) - - val propagate : t -> X.r -> Domains.t -> Domains.t - (** [propagate cs r dom] propagates the constraints associated with [r] in the - constraint set [cs] and returns the new domain map after propagation. *) -end = struct - module IM = Util.MI - - module CS = Set.Make(struct - type t = Constraint.t - - let compare t1 t2 = - Int.compare Constraint.(tag t1.repr) Constraint.(tag t2.repr) - end) - - type t = { - cs_set : CS.t ; - (*** All the constraints currently active *) - cs_map : CS.t MX.t ; - (*** Mapping from semantic values to the constraints that involves them *) - fresh : CS.t ; - (*** Fresh constraints that have never been propagated *) - } - - let pp ppf { cs_set; cs_map = _ ; fresh = _ } = - Fmt.( - braces @@ hvbox @@ - iter ~sep:semi CS.iter @@ - box ~indent:2 @@ braces @@ - Constraint.pp - ) ppf cs_set - - let empty = - { cs_set = CS.empty - ; cs_map = MX.empty - ; fresh = CS.empty } - - let cs_add cs r cs_map = - MX.update r (function - | Some css -> Some (CS.add cs css) - | None -> Some (CS.singleton cs) - ) cs_map - - let cs_remove cs r cs_map = - MX.update r (function - | Some css -> - let css = CS.remove cs css in - if CS.is_empty css then None else Some css - | None -> - (* Can happen if the same argument is repeated *) - None - ) cs_map - - let subst ex rr nrr bcs = - match MX.find rr bcs.cs_map with - | ids -> - let cs_map, cs_set, fresh = - CS.fold (fun cs (cs_map, cs_set, fresh) -> - let fresh = CS.remove cs fresh in - let cs_set = CS.remove cs cs_set in - let cs_map = Constraint.fold_deps (cs_remove cs) cs cs_map in - let cs' = Constraint.subst ex rr nrr cs in - if CS.mem cs' cs_set then - cs_map, cs_set, fresh - else - let cs_set = CS.add cs' cs_set in - let cs_map = Constraint.fold_deps (cs_add cs') cs' cs_map in - (cs_map, cs_set, CS.add cs' fresh) - ) ids (bcs.cs_map, bcs.cs_set, bcs.fresh) - in - assert (not (MX.mem rr cs_map)); - { cs_set ; cs_map ; fresh } - | exception Not_found -> bcs - - let add bcs c = - if CS.mem c bcs.cs_set then - bcs - else - let cs_set = CS.add c bcs.cs_set in - let cs_map = - Constraint.fold_deps (cs_add c) c bcs.cs_map - in - let fresh = CS.add c bcs.fresh in - { cs_set ; cs_map ; fresh } - - let fold_fresh f bcs acc = - let acc = CS.fold f bcs.fresh acc in - { bcs with fresh = CS.empty }, acc - - let fold_r f bcs acc = - MX.fold (fun r _ acc -> f r acc) bcs.cs_map acc - - let propagate bcs r dom = - match MX.find r bcs.cs_map with - | cs -> CS.fold Constraint.propagate cs dom - | exception Not_found -> dom + let bvand ~ex x y z = make ~ex @@ Band (x, y, z) + let bvor ~ex x y z = make ~ex @@ Bor (x, y, z) + let bvxor ~ex x y z = + let xs = SX.singleton x in + let xs = if SX.mem y xs then SX.remove y xs else SX.add y xs in + let xs = if SX.mem z xs then SX.remove z xs else SX.add z xs in + make ~ex @@ Bxor xs + let bvnot ~ex x y = make ~ex @@ Bnot (x, y) end -(* Add one constraint and register its arguments as relevant for congruence *) -let add_constraint (bcs, cgr) c = - let bcs = Constraints.add bcs c in - let cgr = Constraint.fold_deps Congruence.add c cgr in - (bcs, cgr) +module Constraints = Rel_utils.Constraints_Make(Constraint) -let extract_constraints (bcs, cgr) uf r t = +let extract_constraints bcs uf r t = match E.term_view t with (* BVnot is already internalized in the Shostak but we want to know about it without needing a round-trip through Uf *) | { f = Op BVnot; xs = [ x ] ; _ } -> let rx, exx = Uf.find uf x in - add_constraint (bcs, cgr) @@ - { repr = Constraint.hcons @@ Bnot (r, rx) ; ex = exx } + Constraints.add bcs @@ Constraint.bvnot ~ex:exx r rx | { f = Op BVand; xs = [ x; y ]; _ } -> let rx, exx = Uf.find uf x and ry, exy = Uf.find uf y in - add_constraint (bcs, cgr) @@ - { repr = Constraint.hcons @@ Band (r, rx, ry) ; ex = Ex.union exx exy } + Constraints.add bcs @@ Constraint.bvand ~ex:(Ex.union exx exy) r rx ry | { f = Op BVor; xs = [ x; y ]; _ } -> let rx, exx = Uf.find uf x and ry, exy = Uf.find uf y in - add_constraint (bcs, cgr) @@ - { repr = Constraint.hcons @@ Bor (r, rx, ry) ; ex = Ex.union exx exy } + Constraints.add bcs @@ Constraint.bvor ~ex:(Ex.union exx exy) r rx ry | { f = Op BVxor; xs = [ x; y ]; _ } -> let rx, exx = Uf.find uf x and ry, exy = Uf.find uf y in - let xs = SX.singleton r in - let xs = if SX.mem rx xs then SX.remove rx xs else SX.add rx xs in - let xs = if SX.mem ry xs then SX.remove ry xs else SX.add ry xs in - add_constraint (bcs, cgr) @@ - { repr = Constraint.hcons @@ Bxor xs ; ex = Ex.union exx exy } - | _ -> (bcs, cgr) + Constraints.add bcs @@ Constraint.bvxor ~ex:(Ex.union exx exy) r rx ry + | _ -> bcs let rec mk_eq ex lhs w z = match lhs with @@ -693,19 +534,16 @@ let add_eqs = includes constraints that changed due to substitutions) - The constraints involving variables whose domain changed since the last propagation *) -let propagate cgr = +let propagate = let rec propagate changed bcs dom = match Domains.choose_changed dom with - | r, dom -> ( - propagate (SX.add r changed) bcs @@ - Congruence.fold_parents (Constraints.propagate bcs) cgr r dom - ) + | r, dom -> + propagate (SX.add r changed) bcs @@ + Constraints.propagate bcs r dom | exception Not_found -> changed, dom in fun bcs dom -> - let bcs, dom = - Constraints.fold_fresh Constraint.propagate bcs dom - in + let bcs, dom = Constraints.propagate_fresh bcs dom in let changed, dom = propagate SX.empty bcs dom in SX.fold (fun r acc -> add_eqs acc (Shostak.Bitv.embed r) (Domains.get r dom) @@ -715,22 +553,20 @@ type t = { delayed : Rel_utils.Delayed.t ; domain : Domains.t ; constraints : Constraints.t - ; congruence : Congruence.t ; size_splits : Q.t } let empty _ = { delayed = Rel_utils.Delayed.create dispatch ; domain = Domains.empty ; constraints = Constraints.empty - ; congruence = Congruence.empty ; size_splits = Q.one } let assume env uf la = let delayed, result = Rel_utils.Delayed.assume env.delayed uf la in - let (congruence, domain, constraints, eqs, size_splits) = + let (domain, constraints, eqs, size_splits) = try - let (congruence, (constraints, domain), size_splits) = - List.fold_left (fun (cgr, (bcs, dom), ss) (a, _root, ex, orig) -> + let ((constraints, domain), size_splits) = + List.fold_left (fun ((bcs, dom), ss) (a, _root, ex, orig) -> let ss = match orig with | Th_util.CS (Th_bitv, n) -> Q.(ss * n) @@ -744,10 +580,8 @@ let assume env uf la = match a, orig with | L.Eq (rr, nrr), Subst when is_bv_r rr -> let dom = Domains.subst ex rr nrr dom in - let cgr, bcs = - Congruence.subst rr nrr cgr (Constraints.subst ex) bcs - in - (cgr, (bcs, dom), ss) + let bcs = Constraints.subst ex rr nrr bcs in + ((bcs, dom), ss) | L.Distinct (false, [rr; nrr]), _ when is_1bit rr -> (* We don't (yet) support [distinct] in general, but we must support it for case splits to avoid looping. @@ -760,17 +594,16 @@ let assume env uf la = let rr, exrr = Uf.find_r uf rr in let nrr, exnrr = Uf.find_r uf nrr in let ex = Ex.union ex (Ex.union exrr exnrr) in - let bcs, cgr = - add_constraint (bcs, cgr) @@ - { repr = Constraint.hcons @@ Bnot (rr, nrr) ; ex } + let bcs = + Constraints.add bcs @@ Constraint.bvnot ~ex rr nrr in - (cgr, (bcs, dom), ss) - | _ -> (cgr, (bcs, dom), ss) + ((bcs, dom), ss) + | _ -> ((bcs, dom), ss) ) - (env.congruence, (env.constraints, env.domain), env.size_splits) + ((env.constraints, env.domain), env.size_splits) la in - let eqs, constraints, domain = propagate congruence constraints domain in + let eqs, constraints, domain = propagate constraints domain in if Options.get_debug_bitv () && not (Lists.is_empty eqs) then ( Printer.print_dbg ~module_name:"Bitv_rel" ~function_name:"assume" @@ -779,7 +612,7 @@ let assume env uf la = ~module_name:"Bitv_rel" ~function_name:"assume" "bitlist constraints: @[%a@]" Constraints.pp constraints; ); - (congruence, domain, constraints, eqs, size_splits) + (domain, constraints, eqs, size_splits) with Bitlist.Inconsistent ex -> raise @@ Ex.Inconsistent (ex, Uf.cl_extract uf) in @@ -789,7 +622,7 @@ let assume env uf la = let result = { result with assume = List.rev_append assume result.assume } in - { delayed ; constraints ; domain ; congruence ; size_splits }, result + { delayed ; constraints ; domain ; size_splits }, result let query _ _ _ = None @@ -860,11 +693,9 @@ let add env uf r t = try let dr = Bitlist.unknown n Ex.empty in let dom = Domains.update Ex.empty r env.domain dr in - let (bcs, congruence) = - extract_constraints (env.constraints, env.congruence) uf r t - in - let eqs', bcs, dom = propagate congruence bcs dom in - { env with congruence ; constraints = bcs ; domain = dom }, + let bcs = extract_constraints env.constraints uf r t in + let eqs', bcs, dom = propagate bcs dom in + { env with constraints = bcs ; domain = dom }, List.rev_append eqs' eqs with Bitlist.Inconsistent ex -> raise @@ Ex.Inconsistent (ex, Uf.cl_extract uf) diff --git a/src/lib/reasoners/rel_utils.ml b/src/lib/reasoners/rel_utils.ml index 370568371..f79d48d21 100644 --- a/src/lib/reasoners/rel_utils.ml +++ b/src/lib/reasoners/rel_utils.ml @@ -193,139 +193,186 @@ end = struct env, { Sig_rel.assume = assume_nontrivial_eqs eqs la; remove = [] } end -module Congruence : sig - (** The [Congruence] module implements a simil-congruence closure algorithm on - semantic values. +module type Constraint = sig + type t + (** The type of constraints. - It provides an interface to register some semantic values of interest, and - for applying a callback when the representative of those registered values - change. - *) + Constraints are associated with a justification as to why they are + currently valid. The justification is only used to update domains, + identical constraints with different justifications will otherwise behave + identically (and, notably, will compare equal). - type t - (** The type of congruences. *) + Constraints contains semantic values / term representatives of type + [X.r]. We maintain the invariant that the semantic values used inside the + constraints are *class representatives* i.e. normal forms wrt the `Uf` + module, i.e. constraints have a normalized representation. Use `subst` to + ensure normalization. *) - val empty : t - (** The empty congruence. *) + val pp : t Fmt.t + (** Pretty-printer for constraints. *) - val add : X.r -> t -> t - (** [add r t] registers the semantic value [r] in the congruence. *) + val compare : t -> t -> int + (** Comparison function for constraints. - val remove : X.r -> t -> t - (** [remove r t] unregisters the semantic value [r] from the congruence. + Constraints typically include explanations, which should not be included + in the comparison function: code working with constraints expects + constraints with identical representations but different explanations to + compare equal. - Note that if substitutions have been applied to the congruence after a - value has been added, those same substitutions must be applied to the - semantic value prior to calling [remove], or [Invalid_argument] will be - raised. + {b Note}: The comparison function is arbitrary and has no semantic + meaning. You should not depend on any of its properties, other than it + defines an (arbitrary) total order on constraint representations. *) - @raise [Invalid_argument] if [r] is not a registered semantic value. *) + val subst : Explanation.t -> X.r -> X.r -> t -> t + (** [subst ex p v cs] replaces all the instances of [p] with [v] in the + constraint. - val subst : X.r -> X.r -> t -> (X.r -> X.r -> 'a -> 'a) -> 'a -> t * 'a - (** [subst p v t f x] performs a local congruence closure of the - substitution [p -> v]. + Use this to ensure that the representation is always normalized. - More precisely, it will fold [f] over the pairs [(rr, nrr)] such that: - - [rr] was registered in the congruence - - [nrr] is [X.subst p v rr] + The explanation [ex] justifies the equality [p = v]. *) - For each such pair, [rr] is then unregistered from the congruence, and - [nrr] is registered instead. + val fold_leaves : (X.r -> 'a -> 'a) -> t -> 'a -> 'a - [f] is intended to perform a substitution operation on the type ['a], - merging the values associated with [rr] into the values associated with - [nrr]. *) + type domain + (** The type of domains. - val fold_parents : (X.r -> 'a -> 'a) -> t -> X.r -> 'a -> 'a - (** [fold_parents f t r acc] folds function [f] over all the registered terms - whose current representative contains [r] as a leaf. *) -end = struct - module SX = Shostak.SXH - module MX = Shostak.MXH + This is typically a mapping from variables to their own domain, but no + expectations is made upon the actual structure of that type. *) - type t = - { parents : SX.t MX.t - (** Map from leaves to terms that contain them as a leaf. + val propagate : t -> domain -> domain + (** [propagate c dom] propagates the constraints [c] in [d] and returns the + new domain. *) - [p] is in [parents(x)] => [x] is in [leaves(p)] *) - ; registered : SX.t - (** The set of terms we care about. If [x] is in [registered], - then [x] is also in [parents(y)] for each [y] in [leaves(x)]. *) - } +end - let empty = { parents = MX.empty ; registered = SX.empty } +module Constraints_Make(Constraint : Constraint) : sig + type t + (** The type of constraint sets. A constraint sets records a set of + constraints that applies to semantic values, and remembers which + constraints are associated with each semantic values. - let fold_parents f { parents; _ } r acc = - match MX.find r parents with - | deps -> SX.fold f deps acc - | exception Not_found -> acc + It is used to only propagate constraints involving semantic values whose + associated domain has changed. - let add r t = - if SX.mem r t.registered then - t - else - let parents = - List.fold_left (fun parents leaf -> - MX.update leaf (function - | Some deps -> Some (SX.add r deps) - | None -> Some (SX.singleton r) - ) parents - ) t.parents (X.leaves r) - in - { parents ; registered = SX.add r t.registered } - - let remove r t = - if SX.mem r t.registered then - let parents = - List.fold_left (fun parents leaf -> - MX.update leaf (function - | Some deps -> - let deps = SX.remove r deps in - if SX.is_empty deps then None else Some deps - | None -> - (* [r] is in registered, and [leaf] is in [leaves(r)], so - [r] must be in [parents(leaf)]. *) - assert false - ) parents - ) t.parents (X.leaves r) - in - { parents ; registered = SX.remove r t.registered } - else - invalid_arg "Congruence.remove" - - let subst rr nrr cgr f t = - match MX.find rr cgr.parents with - | rr_deps -> - let cgr = { cgr with parents = MX.remove rr cgr.parents } in - SX.fold (fun r (cgr, t) -> - let r' = X.subst rr nrr r in - (* [r] contains [rr] as a leaf by definition *) - assert (not (X.equal r r')); - - (* Update the other leaves *) - let parents = - List.fold_left (fun parents other_leaf -> - if X.equal other_leaf rr then - parents - else - MX.update other_leaf (function - | Some deps -> - let deps = SX.remove r deps in - if SX.is_empty deps then None else Some deps - | None -> assert false - ) parents - ) cgr.parents (X.leaves r) - in + The constraint sets are expected to keep track of *class representatives*, + i.e. normal forms wrt the `Uf` module, in which case we say the + constraint set is *normalized*. Use `subst` to ensure normalization. *) + + val pp : t Fmt.t + (** Pretty-printer for constraint sets. *) + + val empty : t + (** Returns an empty constraint set. *) + + val subst : Explanation.t -> X.r -> X.r -> t -> t + (** [subst ex p v cs] replaces all the instances of [p] with [v] in the + constraints. - (* It is no longer here, but it could be added back later -- let's not - skip it! *) - let registered = SX.remove r cgr.registered in + Use this to ensure that the representation is always normalized. - (* Add the new representative to the congruence if needed *) - let cgr = add r' { parents ; registered } in + The explanation [ex] justifies the equality [p = v]. *) - (* Propagate the substitution *) - cgr, f r r' t - ) rr_deps (cgr, t) - | exception Not_found -> cgr, t + val add : t -> Constraint.t -> t + (** [add c cs] adds the constraint [c] to [cs]. *) + + val propagate_fresh : t -> Constraint.domain -> t * Constraint.domain + (** [propagate_fresh cs acc] propagates the fresh constraints and returns the + new domain, as well as a copy of the constraint set with no fresh + constraints. + + Fresh constraints are constraints that were never propagated yet. *) + + val fold_r : (X.r -> 'a -> 'a) -> t -> 'a -> 'a + (** [fold_r f cs acc] folds [f] over any representative [r] that is currently + associated with a constraint (i.e. at least one constraint currently + applies to [r]). *) + + val propagate : t -> X.r -> Constraint.domain -> Constraint.domain + (** [propagate cs r dom] propagates the constraints associated with [r] in the + constraint set [cs] and returns the new domain map after propagation. *) +end = struct + module IM = Util.MI + module MX = Shostak.MXH + + module CS = Set.Make(Constraint) + + type t = { + cs_set : CS.t ; + (*** All the constraints currently active *) + cs_map : CS.t MX.t ; + (*** Mapping from semantic values to the constraints that involves them *) + fresh : CS.t ; + (*** Fresh constraints that have never been propagated *) + } + + let pp ppf { cs_set; cs_map = _ ; fresh = _ } = + Fmt.( + braces @@ hvbox @@ + iter ~sep:semi CS.iter @@ + box ~indent:2 @@ braces @@ + Constraint.pp + ) ppf cs_set + + let empty = + { cs_set = CS.empty + ; cs_map = MX.empty + ; fresh = CS.empty } + + let cs_add cs r cs_map = + MX.update r (function + | Some css -> Some (CS.add cs css) + | None -> Some (CS.singleton cs) + ) cs_map + + let cs_remove cs r cs_map = + MX.update r (function + | Some css -> + let css = CS.remove cs css in + if CS.is_empty css then None else Some css + | None -> + (* Can happen if the same argument is repeated *) + None + ) cs_map + + let subst ex rr nrr bcs = + match MX.find rr bcs.cs_map with + | ids -> + let cs_map, cs_set, fresh = + CS.fold (fun cs (cs_map, cs_set, fresh) -> + let fresh = CS.remove cs fresh in + let cs_set = CS.remove cs cs_set in + let cs_map = Constraint.fold_leaves (cs_remove cs) cs cs_map in + let cs' = Constraint.subst ex rr nrr cs in + if CS.mem cs' cs_set then + cs_map, cs_set, fresh + else + let cs_set = CS.add cs' cs_set in + let cs_map = Constraint.fold_leaves (cs_add cs') cs' cs_map in + (cs_map, cs_set, CS.add cs' fresh) + ) ids (bcs.cs_map, bcs.cs_set, bcs.fresh) + in + assert (not (MX.mem rr cs_map)); + { cs_set ; cs_map ; fresh } + | exception Not_found -> bcs + + let add bcs c = + if CS.mem c bcs.cs_set then + bcs + else + let cs_set = CS.add c bcs.cs_set in + let cs_map = Constraint.fold_leaves (cs_add c) c bcs.cs_map in + let fresh = CS.add c bcs.fresh in + { cs_set ; cs_map ; fresh } + + let fold_r f bcs acc = + MX.fold (fun r _ acc -> f r acc) bcs.cs_map acc + + let propagate bcs r dom = + match MX.find r bcs.cs_map with + | cs -> CS.fold Constraint.propagate cs dom + | exception Not_found -> dom + + let propagate_fresh bcs dom = + let dom = CS.fold Constraint.propagate bcs.fresh dom in + { bcs with fresh = CS.empty }, dom end