Skip to content

Commit

Permalink
Fixes cached solver (Closes #203)
Browse files Browse the repository at this point in the history
  • Loading branch information
filipeom committed Oct 16, 2024
1 parent 613bee2 commit be060e6
Showing 1 changed file with 68 additions and 16 deletions.
84 changes: 68 additions & 16 deletions src/solvers/solver.ml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ module Base (M : Mappings_intf.S) = struct

let get_value (solver : M.solver) (e : Expr.t) : Expr.t =
match M.Solver.model solver with
| Some m -> Expr.make @@ Val (M.value m e)
| Some m -> Expr.value @@ M.value m e
| None ->
Fmt.failwith "get_value: Trying to get a value from an unsat solver"

Expand All @@ -57,7 +57,10 @@ module Base (M : Mappings_intf.S) = struct
M.values_of_model ?symbols model
end

module Make_batch (Mappings : Mappings_intf.S) = struct
module Incremental (M : Mappings_intf.S) : Solver_intf.S =
Base [@inlined hint] (M)

module Batch (Mappings : Mappings.S) = struct
include Base (Mappings)

type solver = Mappings.solver
Expand All @@ -71,13 +74,10 @@ module Make_batch (Mappings : Mappings_intf.S) = struct
let pp_statistics fmt s = pp_statistics fmt s.solver

let create ?params ?logic () =
{ solver = Mappings.Solver.make ?params ?logic ()
; top = []
; stack = Stack.create ()
}
{ solver = create ?params ?logic (); top = []; stack = Stack.create () }

let clone ({ solver; top; stack } : t) : t =
{ solver; top; stack = Stack.copy stack }
{ solver = clone solver; top; stack = Stack.copy stack }

let push ({ top; stack; solver } : t) : unit =
Mappings.Solver.push solver;
Expand Down Expand Up @@ -118,22 +118,74 @@ module Make_batch (Mappings : Mappings_intf.S) = struct
let interrupt { solver; _ } = interrupt solver
end

(* TODO: Our base solver can be incrmental itself? *)
module Batch (M : Mappings_intf.S) : Solver_intf.S = Make_batch (M)

module Cached (M : Mappings_intf.S) = struct
include Make_batch (M)
module Cached (Mappings : Mappings.S) = struct
include Base (Mappings)
module Cache = Cache.Strong

let cache = Cache.create 256

type solver = Mappings.solver

type t =
{ solver : solver
; mutable top : Expr.Set.t
; stack : Expr.Set.t Stack.t
}

let pp_statistics fmt s = pp_statistics fmt s.solver

let create ?params ?logic () =
{ solver = create ?params ?logic ()
; top = Expr.Set.empty
; stack = Stack.create ()
}

let clone ({ solver; top; stack } : t) : t =
{ solver = clone solver; top; stack = Stack.copy stack }

let push ({ top; stack; solver } : t) : unit =
Mappings.Solver.push solver;
Stack.push top stack

let rec pop (s : t) (lvl : int) : unit =
assert (lvl <= Stack.length s.stack);
if lvl <= 0 then ()
else begin
Mappings.Solver.pop s.solver 1;
s.top <- Stack.pop s.stack;
pop s (lvl - 1)
end

let reset (s : t) =
Mappings.Solver.reset s.solver;
Stack.clear s.stack;
s.top <- Expr.Set.empty

let add (s : t) (es : Expr.t list) : unit =
s.top <- Expr.Set.(union (of_list es) s.top)

let add_set s es = s.top <- Expr.Set.union es s.top

let get_assertions (s : t) : Expr.t list = Expr.Set.to_list s.top [@@inline]

let get_statistics (s : t) : Statistics.t = get_statistics s.solver

let check_set s es =
match Cache.find_opt cache es with
let assert_ = Expr.Set.union es s.top in
match Cache.find_opt cache assert_ with
| Some res -> res
| None ->
let result = check_set s es in
let result = check_set s.solver assert_ in
Cache.add cache es result;
result
end

module Incremental (M : Mappings_intf.S) : Solver_intf.S = Base (M)
let check (s : t) (es : Expr.t list) : satisfiability =
check_set s (Expr.Set.of_list es)

let get_value (solver : t) (e : Expr.t) : Expr.t = get_value solver.solver e

let model ?(symbols : Symbol.t list option) (s : t) : Model.t option =
model ?symbols s.solver

let interrupt { solver; _ } = interrupt solver
end

0 comments on commit be060e6

Please sign in to comment.