Skip to content

Commit

Permalink
Exporing rounding mode to SMT
Browse files Browse the repository at this point in the history
  • Loading branch information
Stevendeo committed Oct 12, 2023
1 parent 21479bf commit 2700cd7
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 70 deletions.
11 changes: 9 additions & 2 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ let fpa_builtins =
| _ -> assert false
in
let fpa_rounding_mode, rounding_modes, add_rounding_modes =
builtin_enum Fpa_rounding.fpa_rounding_mode
builtin_enum (Fpa_rounding.fpa_rounding_mode ())
in
let float_cst =
let ty = DT.(arrow [int; int; fpa_rounding_mode; real] real) in
Expand Down Expand Up @@ -429,11 +429,18 @@ let fpa_builtins =
| Builtin _ -> `Not_found
end

(** Concatenation of builtins handlers. *)
(* let (++) bt1 bt2 =
* fun a b ->
* match bt1 a b with
* | `Not_found -> bt2 a b
* | res -> res *)

let builtins =
fun _st (lang : Typer.lang) ->
match lang with
| `Logic Alt_ergo -> fpa_builtins
| `Logic (Smtlib2 _) -> bv_builtins
| `Logic (Smtlib2 _) -> (* fpa_builtins ++ *) bv_builtins
| _ -> fun _ _ -> `Not_found

(** Translates dolmen locs to Alt-Ergo's locs *)
Expand Down
15 changes: 10 additions & 5 deletions src/lib/frontend/typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ module Env = struct
c = { tt_desc = TTconst (Tint n); tt_ty = Ty.Tint} ;
annot = new_id () ;
} in
let rm = Fpa_rounding.fpa_rounding_mode in
let rm = Fpa_rounding.fpa_rounding_mode () in
let mode m =
let h = find_builtin_cstr rm m in
{
Expand All @@ -298,10 +298,12 @@ module Env = struct
let float prec exp mode x =
TTapp (Symbols.Op Float, [prec; exp; mode; x])
in
let nte = Fpa_rounding.string_of_rounding_mode NearestTiesToEven in
let tname = Fpa_rounding.fpa_rounding_mode_type_name () in
let float32 = float (int "24") (int "149") in
let float32d = float32 (mode "NearestTiesToEven") in
let float32d = float32 (mode nte) in
let float64 = float (int "53") (int "1074") in
let float64d = float64 (mode "NearestTiesToEven") in
let float64d = float64 (mode nte) in
let op n op profile =
MString.add n @@ `Term (Symbols.Op op, profile, Other)
in
Expand All @@ -312,8 +314,11 @@ module Env = struct
let any = Ty.fresh_tvar in
let env = {
env with
types = Types.add_builtin env.types "fpa_rounding_mode" rm ;
builtins = add_builtin_enum Fpa_rounding.fpa_rounding_mode env.builtins;
types = Types.add_builtin env.types tname rm ;
builtins =
add_builtin_enum
(Fpa_rounding.fpa_rounding_mode ())
env.builtins;
} in
let builtins =
env.builtins
Expand Down
2 changes: 1 addition & 1 deletion src/lib/structures/expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2819,7 +2819,7 @@ let const_view t =
Fmt.failwith "error when trying to convert %a to an int" Z.pp_print n
end
| { f = Op (Constr c); ty; _ }
when Ty.equal ty Fpa_rounding.fpa_rounding_mode ->
when Ty.equal ty (Fpa_rounding.fpa_rounding_mode ()) ->
RoundingMode (Fpa_rounding.rounding_mode_of_hs c)
| _ -> Fmt.failwith "unsupported constant: %a" print t

Expand Down
133 changes: 85 additions & 48 deletions src/lib/structures/fpa_rounding.ml
Original file line number Diff line number Diff line change
Expand Up @@ -33,64 +33,104 @@ module Hs = Hstring
module Q = Numbers.Q
module Z = Numbers.Z

(** The five standard rounding modes of the SMTLIB.
Note that the SMTLIB defines these rounding modes to be the only
possible modes. *)
type rounding_mode =
(* five standard/why3 fpa rounding modes *)
| NearestTiesToEven
(*ne in Gappa: to nearest, tie breaking to even mantissas*)
| ToZero (* zr in Gappa: toward zero *)
| Up (* up in Gappa: toward plus infinity *)
| Down (* dn in Gappa: toward minus infinity *)
| NearestTiesToAway (* na : to nearest, tie breaking away from zero *)

(* additional Gappa rounding modes *)
| Aw (* aw in Gappa: away from zero **)
| Od (* od in Gappa: to odd mantissas *)
| No (* no in Gappa: to nearest, tie breaking to odd mantissas *)
| Nz (* nz in Gappa: to nearest, tie breaking toward zero *)
| Nd (* nd in Gappa: to nearest, tie breaking toward minus infinity *)
| Nu (* nu in Gappa: to nearest, tie breaking toward plus infinity *)

let pp_rounding_mode ppf m =
Format.pp_print_string ppf @@
| ToZero
| Up
| Down
| NearestTiesToAway

let string_of_rounding_mode_ae m =
match m with
| NearestTiesToEven -> "NearestTiesToEven"
| ToZero -> "ToZero"
| Up -> "Up"
| Down -> "Down"
| NearestTiesToAway -> "NearestTiesToAway"
| Aw -> "Aw"
| Od -> "Od"
| No -> "No"
| Nz -> "Nz"
| Nd -> "Nd"
| Nu -> "Nu"

let fpa_rounding_mode, rounding_mode_of_hs =
let cstrs =
[ (* standards *)
NearestTiesToEven;
ToZero;
Up;
Down;
NearestTiesToAway;
(* non standards *)
Aw;
Od;
No;
Nz;
Nd;
Nu ]
in

let string_of_rounding_mode_smt2 m =
match m with
| NearestTiesToEven -> "RNE"
| ToZero -> "RTZ"
| Up -> "RTP"
| Down -> "RTN"
| NearestTiesToAway -> "RNA"

let cstrs =
[
NearestTiesToEven;
ToZero;
Up;
Down;
NearestTiesToAway;
]

let if_smt ~then_ ~else_ =
match Options.get_input_format () with
| None | Some Smtlib2 -> then_
| _ -> else_

let typ_and_rounding pp tname =
let h_cstrs =
List.map (fun c -> Hs.make (Format.asprintf "%a" pp_rounding_mode c)) cstrs
List.map (fun c -> Hs.make (pp c)) cstrs
in
let ty = Ty.Tsum (Hs.make "fpa_rounding_mode", h_cstrs) in
let ty = Ty.Tsum (Hs.make tname, h_cstrs) in
let table =
let table = Hashtbl.create 17 in
List.iter2 (Hashtbl.add table) h_cstrs cstrs;
let table = Hashtbl.create 5 in
List.iter2 (
fun key bnd ->
Hashtbl.add table key bnd
) h_cstrs cstrs;
table
in
ty, Hashtbl.find table
ty,
(fun key -> match Hashtbl.find_opt table key with
| None ->
Fmt.failwith
"%a"
(fun fmt k ->
Format.pp_print_string fmt tname;
Hashtbl.iter (fun key _ -> Format.fprintf fmt "%a --" Hstring.print key) table;
Hstring.print fmt k
)
key
| Some res -> res)

let ae_type_name = "fpa_rounding_mode"
let smt2_type_name = "RoundingMode"

let fpa_rounding_mode_ae, rounding_mode_of_hs_ae =
typ_and_rounding string_of_rounding_mode_ae ae_type_name

let fpa_rounding_mode_smt, rounding_mode_of_hs_smt =
typ_and_rounding string_of_rounding_mode_smt2 smt2_type_name

let rounding_mode_of_hs_smt hs =
Format.printf "rounding_mode_smt2@.";
rounding_mode_of_hs_smt hs

let rounding_mode_of_hs_ae hs =
Format.printf "rounding_mode_ae@.";
rounding_mode_of_hs_ae hs

let fpa_rounding_mode_type_name () =
if_smt ~then_:smt2_type_name ~else_:ae_type_name

let fpa_rounding_mode () =
if_smt ~then_:fpa_rounding_mode_smt ~else_:fpa_rounding_mode_ae

let rounding_mode_of_hs hs =
Format.pp_print_bool Format.std_formatter (if_smt ~then_:true ~else_:false);
if_smt ~then_:(rounding_mode_of_hs_smt) ~else_:(rounding_mode_of_hs_ae) hs


let string_of_rounding_mode m =
if_smt
~then_:(string_of_rounding_mode_smt2 m)
~else_:(string_of_rounding_mode_ae m)

(** Helper functions **)

Expand Down Expand Up @@ -160,9 +200,6 @@ let round_big_int (mode : rounding_mode) y =
if Q.sign diff = 0 then z
else if Q.compare diff half < 0 then z else Z.add z (signed_one y)

| Aw | Od | No | Nz | Nd | Nu -> assert false


let to_mantissa_exp prec exp mode x =
let sign_x = Q.sign x in
assert ((sign_x = 0) == Q.equal x Q.zero);
Expand Down
22 changes: 8 additions & 14 deletions src/lib/structures/fpa_rounding.mli
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,20 @@
(**************************************************************************)

type rounding_mode =
(* five standard/why3 fpa rounding modes *)
| NearestTiesToEven
(*ne in Gappa: to nearest, tie breaking to even mantissas*)
| ToZero (* zr in Gappa: toward zero *)
| Up (* up in Gappa: toward plus infinity *)
| Down (* dn in Gappa: toward minus infinity *)
| NearestTiesToAway (* na : to nearest, tie breaking away from zero *)
| ToZero
| Up
| Down
| NearestTiesToAway

(* additional Gappa rounding modes *)
| Aw (* aw in Gappa: away from zero **)
| Od (* od in Gappa: to odd mantissas *)
| No (* no in Gappa: to nearest, tie breaking to odd mantissas *)
| Nz (* nz in Gappa: to nearest, tie breaking toward zero *)
| Nd (* nd in Gappa: to nearest, tie breaking toward minus infinity *)
| Nu (* nu in Gappa: to nearest, tie breaking toward plus infinity *)
val fpa_rounding_mode_type_name : unit -> string

val fpa_rounding_mode : Ty.t
val fpa_rounding_mode : unit -> Ty.t

val rounding_mode_of_hs : Hstring.t -> rounding_mode

val string_of_rounding_mode : rounding_mode -> string

(** Integer part of binary logarithm for NON-ZERO POSITIVE number **)
val integer_log_2 : Numbers.Q.t -> int

Expand Down

0 comments on commit 2700cd7

Please sign in to comment.