diff --git a/src/solvers/solver.ml b/src/solvers/solver.ml index 113e48cd..40ebcde2 100644 --- a/src/solvers/solver.ml +++ b/src/solvers/solver.ml @@ -48,7 +48,7 @@ module Base (M : Mappings_intf.S) = struct let get_value (solver : M.solver) (e : Expr.t) : Expr.t = match M.Solver.model solver with - | Some m -> Expr.make @@ Val (M.value m e) + | Some m -> Expr.value @@ M.value m e | None -> Fmt.failwith "get_value: Trying to get a value from an unsat solver" @@ -57,7 +57,10 @@ module Base (M : Mappings_intf.S) = struct M.values_of_model ?symbols model end -module Make_batch (Mappings : Mappings_intf.S) = struct +module Incremental (M : Mappings_intf.S) : Solver_intf.S = + Base [@inlined hint] (M) + +module Batch (Mappings : Mappings.S) = struct include Base (Mappings) type solver = Mappings.solver @@ -71,13 +74,10 @@ module Make_batch (Mappings : Mappings_intf.S) = struct let pp_statistics fmt s = pp_statistics fmt s.solver let create ?params ?logic () = - { solver = Mappings.Solver.make ?params ?logic () - ; top = [] - ; stack = Stack.create () - } + { solver = create ?params ?logic (); top = []; stack = Stack.create () } let clone ({ solver; top; stack } : t) : t = - { solver; top; stack = Stack.copy stack } + { solver = clone solver; top; stack = Stack.copy stack } let push ({ top; stack; solver } : t) : unit = Mappings.Solver.push solver; @@ -118,22 +118,74 @@ module Make_batch (Mappings : Mappings_intf.S) = struct let interrupt { solver; _ } = interrupt solver end -(* TODO: Our base solver can be incrmental itself? *) -module Batch (M : Mappings_intf.S) : Solver_intf.S = Make_batch (M) - -module Cached (M : Mappings_intf.S) = struct - include Make_batch (M) +module Cached (Mappings : Mappings.S) = struct + include Base (Mappings) module Cache = Cache.Strong let cache = Cache.create 256 + type solver = Mappings.solver + + type t = + { solver : solver + ; mutable top : Expr.Set.t + ; stack : Expr.Set.t Stack.t + } + + let pp_statistics fmt s = pp_statistics fmt s.solver + + let create ?params ?logic () = + { solver = create ?params ?logic () + ; top = Expr.Set.empty + ; stack = Stack.create () + } + + let clone ({ solver; top; stack } : t) : t = + { solver = clone solver; top; stack = Stack.copy stack } + + let push ({ top; stack; solver } : t) : unit = + Mappings.Solver.push solver; + Stack.push top stack + + let rec pop (s : t) (lvl : int) : unit = + assert (lvl <= Stack.length s.stack); + if lvl <= 0 then () + else begin + Mappings.Solver.pop s.solver 1; + s.top <- Stack.pop s.stack; + pop s (lvl - 1) + end + + let reset (s : t) = + Mappings.Solver.reset s.solver; + Stack.clear s.stack; + s.top <- Expr.Set.empty + + let add (s : t) (es : Expr.t list) : unit = + s.top <- Expr.Set.(union (of_list es) s.top) + + let add_set s es = s.top <- Expr.Set.union es s.top + + let get_assertions (s : t) : Expr.t list = Expr.Set.to_list s.top [@@inline] + + let get_statistics (s : t) : Statistics.t = get_statistics s.solver + let check_set s es = - match Cache.find_opt cache es with + let assert_ = Expr.Set.union es s.top in + match Cache.find_opt cache assert_ with | Some res -> res | None -> - let result = check_set s es in + let result = check_set s.solver assert_ in Cache.add cache es result; result -end -module Incremental (M : Mappings_intf.S) : Solver_intf.S = Base (M) + let check (s : t) (es : Expr.t list) : satisfiability = + check_set s (Expr.Set.of_list es) + + let get_value (solver : t) (e : Expr.t) : Expr.t = get_value solver.solver e + + let model ?(symbols : Symbol.t list option) (s : t) : Model.t option = + model ?symbols s.solver + + let interrupt { solver; _ } = interrupt solver +end