Skip to content

Commit

Permalink
Use a map to store solver params
Browse files Browse the repository at this point in the history
  • Loading branch information
filipeom committed Oct 14, 2024
1 parent ba7f4a7 commit a1addf6
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 53 deletions.
19 changes: 12 additions & 7 deletions src/cvc5_mappings.default.ml
Original file line number Diff line number Diff line change
Expand Up @@ -431,14 +431,19 @@ module Fresh_cvc5 () = struct
end

module Solver = struct
let set_param (type a) slv (param : a Params.param) (v : a) : unit =
match param with
| Timeout -> Solver.set_option slv "tlimit" (string_of_int v)
| Model -> Solver.set_option slv "produce-models" (string_of_bool v)
| Unsat_core ->
Solver.set_option slv "produce-unsat-core" (string_of_bool v)
| Ematching -> Solver.set_option slv "e-matching" (string_of_bool v)
| Parallel | Num_threads -> ()

let set_params slv params =
Solver.set_option slv "e-matching"
(string_of_bool @@ Params.get params Ematching);
Solver.set_option slv "tlimit" (string_of_int @@ Params.get params Timeout);
Solver.set_option slv "produce-models"
(string_of_bool @@ Params.get params Model);
Solver.set_option slv "produce-unsat-cores"
(string_of_bool @@ Params.get params Unsat_core)
List.iter
(fun (Params.P (p, v)) -> set_param slv p v)
(Params.to_list params)

let make ?params ?logic () =
let logic = Option.map (fun l -> Fmt.str "%a" Ty.pp_logic l) logic in
Expand Down
70 changes: 40 additions & 30 deletions src/solvers/params.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,29 @@ type _ param =
| Parallel : bool param
| Num_threads : int param

type t =
{ timeout : int
; model : bool
; unsat_core : bool
; ematching : bool
; parallel : bool
; num_threads : int
}
let discr : type a. a param -> int = function
| Timeout -> 0
| Model -> 1
| Unsat_core -> 2
| Ematching -> 3
| Parallel -> 4
| Num_threads -> 5

module Key = struct
type t = K : 'a param -> t

let v v = K v

let compare (K a) (K b) = compare (discr a) (discr b)
end

module Pmap = Map.Make (Key)

type param' = P : 'a param * 'a -> param'

let p k v = P (k, v)

type t = param' Pmap.t

let default_timeout = 2147483647

Expand All @@ -43,7 +58,6 @@ let default_ematching = true

let default_parallel = false

(* FIXME: Will this be problematic if only (Parallel, true) is specified? *)
let default_num_threads = 1

let default_value (type a) (param : a param) : a =
Expand All @@ -56,22 +70,14 @@ let default_value (type a) (param : a param) : a =
| Num_threads -> default_num_threads

let default () =
{ timeout = default_timeout
; model = default_model
; unsat_core = default_unsat_core
; ematching = default_ematching
; parallel = default_parallel
; num_threads = default_num_threads
}
Pmap.empty
|> Pmap.add (Key.v Timeout) (p Timeout default_timeout)
|> Pmap.add (Key.v Model) (p Model default_model)
|> Pmap.add (Key.v Unsat_core) (p Unsat_core default_unsat_core)
|> Pmap.add (Key.v Ematching) (p Ematching default_ematching)

let set (type a) (params : t) (param : a param) (value : a) : t =
match param with
| Timeout -> { params with timeout = value }
| Model -> { params with model = value }
| Unsat_core -> { params with unsat_core = value }
| Ematching -> { params with ematching = value }
| Parallel -> { params with parallel = value }
| Num_threads -> { params with num_threads = value }
Pmap.add (Key.v param) (p param value) params

let opt (type a) (params : t) (param : a param) (opt_value : a option) : t =
Option.fold ~none:params ~some:(set params param) opt_value
Expand All @@ -80,10 +86,14 @@ let ( $ ) (type a) (params : t) ((param, value) : a param * a) : t =
set params param value

let get (type a) (params : t) (param : a param) : a =
match param with
| Timeout -> params.timeout
| Model -> params.model
| Unsat_core -> params.unsat_core
| Ematching -> params.ematching
| Parallel -> params.parallel
| Num_threads -> params.num_threads
match (param, Pmap.find (Key.v param) params) with
| Timeout, P (Timeout, v) -> v
| Model, P (Model, v) -> v
| Unsat_core, P (Unsat_core, v) -> v
| Ematching, P (Ematching, v) -> v
| Parallel, P (Parallel, v) -> v
| Num_threads, P (Num_threads, v) -> v
| (Timeout | Model | Unsat_core | Ematching | Parallel | Num_threads), _ ->
assert false

let to_list params = List.map snd @@ Pmap.bindings params
6 changes: 5 additions & 1 deletion src/solvers/params.mli
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
(* along with this program. If not, see <https://www.gnu.org/licenses/>. *)
(***************************************************************************)

type t

type _ param =
| Timeout : int param
(** Specifies a timeout in miliseconds for each [check] call *)
Expand All @@ -26,7 +28,7 @@ type _ param =
| Num_threads : int param
(** Speficied the maximum number of threads to use in parallel mode *)

type t
type param' = P : 'a param * 'a -> param'

val default_value : 'a param -> 'a

Expand All @@ -43,3 +45,5 @@ val opt : t -> 'a param -> 'a option -> t

(** [get params p] fetches the current value for parameter [p] *)
val get : t -> 'a param -> 'a

val to_list : t -> param' list
24 changes: 12 additions & 12 deletions src/z3_mappings.default.ml
Original file line number Diff line number Diff line change
Expand Up @@ -438,19 +438,19 @@ module M = struct
in
List.fold_left add_entry Statistics.Map.empty statistics

let set_param (type a) (param : a Params.param) (v : a) : unit =
match param with
| Timeout -> Z3.Params.update_param_value ctx "timeout" (string_of_int v)
| Model -> Z3.Params.update_param_value ctx "model" (string_of_bool v)
| Unsat_core ->
Z3.Params.update_param_value ctx "unsat_core" (string_of_bool v)
| Ematching -> Z3.set_global_param "smt.ematching" (string_of_bool v)
| Parallel -> Z3.set_global_param "parallel.enable" (string_of_bool v)
| Num_threads ->
Z3.set_global_param "parallel.threads.max" (string_of_int v)

let set_params (params : Params.t) =
Z3.set_global_param "smt.ematching"
(string_of_bool @@ Params.get params Ematching);
Z3.set_global_param "parallel.enable"
(string_of_bool @@ Params.get params Parallel);
Z3.set_global_param "parallel.threads.max"
(string_of_int @@ Params.get params Num_threads);
Z3.Params.update_param_value ctx "timeout"
(string_of_int @@ Params.get params Timeout);
Z3.Params.update_param_value ctx "model"
(string_of_bool @@ Params.get params Model);
Z3.Params.update_param_value ctx "unsat_core"
(string_of_bool @@ Params.get params Unsat_core)
List.iter (fun (Params.P (p, v)) -> set_param p v) (Params.to_list params)

module Solver = struct
(* TODO: parameters *)
Expand Down
8 changes: 5 additions & 3 deletions test/solver/test_solver_params.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ module Make (M : Mappings_intf.S) = struct

let () =
let params =
default () $ (Timeout, 900) $ (Unsat_core, true) $ (Model, false)
$ (Ematching, false)
default () $ (Timeout, 900) $ (Model, false) $ (Unsat_core, true)
$ (Ematching, false) $ (Parallel, true) $ (Num_threads, 1)
in
ignore (Solver.create ~params ())
assert (Params.get params Unsat_core);
let _ : Solver.t = Solver.create ~params () in
()
end

0 comments on commit a1addf6

Please sign in to comment.