diff --git a/src/cvc5_mappings.default.ml b/src/cvc5_mappings.default.ml index 4a53a972..a7c85d69 100644 --- a/src/cvc5_mappings.default.ml +++ b/src/cvc5_mappings.default.ml @@ -431,14 +431,19 @@ module Fresh_cvc5 () = struct end module Solver = struct + let set_param (type a) slv (param : a Params.param) (v : a) : unit = + match param with + | Timeout -> Solver.set_option slv "tlimit" (string_of_int v) + | Model -> Solver.set_option slv "produce-models" (string_of_bool v) + | Unsat_core -> + Solver.set_option slv "produce-unsat-core" (string_of_bool v) + | Ematching -> Solver.set_option slv "e-matching" (string_of_bool v) + | Parallel | Num_threads -> () + let set_params slv params = - Solver.set_option slv "e-matching" - (string_of_bool @@ Params.get params Ematching); - Solver.set_option slv "tlimit" (string_of_int @@ Params.get params Timeout); - Solver.set_option slv "produce-models" - (string_of_bool @@ Params.get params Model); - Solver.set_option slv "produce-unsat-cores" - (string_of_bool @@ Params.get params Unsat_core) + List.iter + (fun (Params.P (p, v)) -> set_param slv p v) + (Params.to_list params) let make ?params ?logic () = let logic = Option.map (fun l -> Fmt.str "%a" Ty.pp_logic l) logic in diff --git a/src/solvers/params.ml b/src/solvers/params.ml index 30283f5b..da4a6749 100644 --- a/src/solvers/params.ml +++ b/src/solvers/params.ml @@ -24,14 +24,29 @@ type _ param = | Parallel : bool param | Num_threads : int param -type t = - { timeout : int - ; model : bool - ; unsat_core : bool - ; ematching : bool - ; parallel : bool - ; num_threads : int - } +let discr : type a. a param -> int = function + | Timeout -> 0 + | Model -> 1 + | Unsat_core -> 2 + | Ematching -> 3 + | Parallel -> 4 + | Num_threads -> 5 + +module Key = struct + type t = K : 'a param -> t + + let v v = K v + + let compare (K a) (K b) = compare (discr a) (discr b) +end + +module Pmap = Map.Make (Key) + +type param' = P : 'a param * 'a -> param' + +let p k v = P (k, v) + +type t = param' Pmap.t let default_timeout = 2147483647 @@ -43,7 +58,6 @@ let default_ematching = true let default_parallel = false -(* FIXME: Will this be problematic if only (Parallel, true) is specified? *) let default_num_threads = 1 let default_value (type a) (param : a param) : a = @@ -56,22 +70,14 @@ let default_value (type a) (param : a param) : a = | Num_threads -> default_num_threads let default () = - { timeout = default_timeout - ; model = default_model - ; unsat_core = default_unsat_core - ; ematching = default_ematching - ; parallel = default_parallel - ; num_threads = default_num_threads - } + Pmap.empty + |> Pmap.add (Key.v Timeout) (p Timeout default_timeout) + |> Pmap.add (Key.v Model) (p Model default_model) + |> Pmap.add (Key.v Unsat_core) (p Unsat_core default_unsat_core) + |> Pmap.add (Key.v Ematching) (p Ematching default_ematching) let set (type a) (params : t) (param : a param) (value : a) : t = - match param with - | Timeout -> { params with timeout = value } - | Model -> { params with model = value } - | Unsat_core -> { params with unsat_core = value } - | Ematching -> { params with ematching = value } - | Parallel -> { params with parallel = value } - | Num_threads -> { params with num_threads = value } + Pmap.add (Key.v param) (p param value) params let opt (type a) (params : t) (param : a param) (opt_value : a option) : t = Option.fold ~none:params ~some:(set params param) opt_value @@ -80,10 +86,14 @@ let ( $ ) (type a) (params : t) ((param, value) : a param * a) : t = set params param value let get (type a) (params : t) (param : a param) : a = - match param with - | Timeout -> params.timeout - | Model -> params.model - | Unsat_core -> params.unsat_core - | Ematching -> params.ematching - | Parallel -> params.parallel - | Num_threads -> params.num_threads + match (param, Pmap.find (Key.v param) params) with + | Timeout, P (Timeout, v) -> v + | Model, P (Model, v) -> v + | Unsat_core, P (Unsat_core, v) -> v + | Ematching, P (Ematching, v) -> v + | Parallel, P (Parallel, v) -> v + | Num_threads, P (Num_threads, v) -> v + | (Timeout | Model | Unsat_core | Ematching | Parallel | Num_threads), _ -> + assert false + +let to_list (type a) (params : t) = List.map snd @@ Pmap.bindings params diff --git a/src/solvers/params.mli b/src/solvers/params.mli index 2ad81765..2fca0951 100644 --- a/src/solvers/params.mli +++ b/src/solvers/params.mli @@ -16,6 +16,8 @@ (* along with this program. If not, see . *) (***************************************************************************) +type t + type _ param = | Timeout : int param (** Specifies a timeout in miliseconds for each [check] call *) @@ -26,7 +28,7 @@ type _ param = | Num_threads : int param (** Speficied the maximum number of threads to use in parallel mode *) -type t +type param' = P : 'a param * 'a -> param' val default_value : 'a param -> 'a @@ -43,3 +45,5 @@ val opt : t -> 'a param -> 'a option -> t (** [get params p] fetches the current value for parameter [p] *) val get : t -> 'a param -> 'a + +val to_list : t -> param' list diff --git a/src/z3_mappings.default.ml b/src/z3_mappings.default.ml index 2cddca64..4fd2094e 100644 --- a/src/z3_mappings.default.ml +++ b/src/z3_mappings.default.ml @@ -438,19 +438,19 @@ module M = struct in List.fold_left add_entry Statistics.Map.empty statistics + let set_param (type a) (param : a Params.param) (v : a) : unit = + match param with + | Timeout -> Z3.Params.update_param_value ctx "timeout" (string_of_int v) + | Model -> Z3.Params.update_param_value ctx "model" (string_of_bool v) + | Unsat_core -> + Z3.Params.update_param_value ctx "unsat_core" (string_of_bool v) + | Ematching -> Z3.set_global_param "smt.ematching" (string_of_bool v) + | Parallel -> Z3.set_global_param "parallel.enable" (string_of_bool v) + | Num_threads -> + Z3.set_global_param "parallel.threads.max" (string_of_int v) + let set_params (params : Params.t) = - Z3.set_global_param "smt.ematching" - (string_of_bool @@ Params.get params Ematching); - Z3.set_global_param "parallel.enable" - (string_of_bool @@ Params.get params Parallel); - Z3.set_global_param "parallel.threads.max" - (string_of_int @@ Params.get params Num_threads); - Z3.Params.update_param_value ctx "timeout" - (string_of_int @@ Params.get params Timeout); - Z3.Params.update_param_value ctx "model" - (string_of_bool @@ Params.get params Model); - Z3.Params.update_param_value ctx "unsat_core" - (string_of_bool @@ Params.get params Unsat_core) + List.iter (fun (Params.P (p, v)) -> set_param p v) (Params.to_list params) module Solver = struct (* TODO: parameters *) diff --git a/test/solver/test_solver_params.ml b/test/solver/test_solver_params.ml index 045bf9b8..668de040 100644 --- a/test/solver/test_solver_params.ml +++ b/test/solver/test_solver_params.ml @@ -12,8 +12,10 @@ module Make (M : Mappings_intf.S) = struct let () = let params = - default () $ (Timeout, 900) $ (Unsat_core, true) $ (Model, false) - $ (Ematching, false) + default () $ (Timeout, 900) $ (Model, false) $ (Unsat_core, true) + $ (Ematching, false) $ (Parallel, true) $ (Num_threads, 1) in - ignore (Solver.create ~params ()) + assert (Params.get params Unsat_core); + let _ : Solver.t = Solver.create ~params () in + () end