Skip to content

Commit

Permalink
Exporing rounding mode to SMT
Browse files Browse the repository at this point in the history
  • Loading branch information
Stevendeo committed Oct 13, 2023
1 parent 21479bf commit eb8bbbf
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 83 deletions.
12 changes: 10 additions & 2 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ let fpa_builtins =
| _ -> assert false
in
let fpa_rounding_mode, rounding_modes, add_rounding_modes =
builtin_enum Fpa_rounding.fpa_rounding_mode
let module FPAU : Fpa_rounding.S = (val (Fpa_rounding.fpa_rounding_utils ())) in
builtin_enum FPAU.fpa_rounding_mode
in
let float_cst =
let ty = DT.(arrow [int; int; fpa_rounding_mode; real] real) in
Expand Down Expand Up @@ -429,11 +430,18 @@ let fpa_builtins =
| Builtin _ -> `Not_found
end

(** Concatenation of builtins handlers. *)
(* let (++) bt1 bt2 =
* fun a b ->
* match bt1 a b with
* | `Not_found -> bt2 a b
* | res -> res *)

let builtins =
fun _st (lang : Typer.lang) ->
match lang with
| `Logic Alt_ergo -> fpa_builtins
| `Logic (Smtlib2 _) -> bv_builtins
| `Logic (Smtlib2 _) -> (* fpa_builtins ++ *) bv_builtins
| _ -> fun _ _ -> `Not_found

(** Translates dolmen locs to Alt-Ergo's locs *)
Expand Down
17 changes: 12 additions & 5 deletions src/lib/frontend/typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,14 @@ module Env = struct
| _ -> assert false

let add_fpa_builtins env =
let module FPAU : Fpa_rounding.S =
(val (Fpa_rounding.fpa_rounding_utils ())) in
let (->.) args result = { args; result } in
let int n = {
c = { tt_desc = TTconst (Tint n); tt_ty = Ty.Tint} ;
annot = new_id () ;
} in
let rm = Fpa_rounding.fpa_rounding_mode in
let rm = FPAU.fpa_rounding_mode in
let mode m =
let h = find_builtin_cstr rm m in
{
Expand All @@ -298,10 +300,12 @@ module Env = struct
let float prec exp mode x =
TTapp (Symbols.Op Float, [prec; exp; mode; x])
in
let nte = FPAU.string_of_rounding_mode NearestTiesToEven in
let tname = FPAU.fpa_rounding_mode_type_name in
let float32 = float (int "24") (int "149") in
let float32d = float32 (mode "NearestTiesToEven") in
let float32d = float32 (mode nte) in
let float64 = float (int "53") (int "1074") in
let float64d = float64 (mode "NearestTiesToEven") in
let float64d = float64 (mode nte) in
let op n op profile =
MString.add n @@ `Term (Symbols.Op op, profile, Other)
in
Expand All @@ -312,8 +316,11 @@ module Env = struct
let any = Ty.fresh_tvar in
let env = {
env with
types = Types.add_builtin env.types "fpa_rounding_mode" rm ;
builtins = add_builtin_enum Fpa_rounding.fpa_rounding_mode env.builtins;
types = Types.add_builtin env.types tname rm ;
builtins =
add_builtin_enum
FPAU.fpa_rounding_mode
env.builtins;
} in
let builtins =
env.builtins
Expand Down
6 changes: 4 additions & 2 deletions src/lib/structures/expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2811,6 +2811,8 @@ type const =
| RoundingMode of Fpa_rounding.rounding_mode

let const_view t =
let module FPAU : Fpa_rounding.S =
(val (Fpa_rounding.fpa_rounding_utils ())) in
match term_view t with
| { f = Int n; _ } ->
begin match Z.to_int n with
Expand All @@ -2819,8 +2821,8 @@ let const_view t =
Fmt.failwith "error when trying to convert %a to an int" Z.pp_print n
end
| { f = Op (Constr c); ty; _ }
when Ty.equal ty Fpa_rounding.fpa_rounding_mode ->
RoundingMode (Fpa_rounding.rounding_mode_of_hs c)
when Ty.equal ty FPAU.fpa_rounding_mode ->
RoundingMode (FPAU.rounding_mode_of_hs c)
| _ -> Fmt.failwith "unsupported constant: %a" print t

let int_view t =
Expand Down
152 changes: 93 additions & 59 deletions src/lib/structures/fpa_rounding.ml
Original file line number Diff line number Diff line change
Expand Up @@ -33,64 +33,101 @@ module Hs = Hstring
module Q = Numbers.Q
module Z = Numbers.Z

(** The five standard rounding modes of the SMTLIB.
Note that the SMTLIB defines these rounding modes to be the only
possible modes. *)
type rounding_mode =
(* five standard/why3 fpa rounding modes *)
| NearestTiesToEven
(*ne in Gappa: to nearest, tie breaking to even mantissas*)
| ToZero (* zr in Gappa: toward zero *)
| Up (* up in Gappa: toward plus infinity *)
| Down (* dn in Gappa: toward minus infinity *)
| NearestTiesToAway (* na : to nearest, tie breaking away from zero *)

(* additional Gappa rounding modes *)
| Aw (* aw in Gappa: away from zero **)
| Od (* od in Gappa: to odd mantissas *)
| No (* no in Gappa: to nearest, tie breaking to odd mantissas *)
| Nz (* nz in Gappa: to nearest, tie breaking toward zero *)
| Nd (* nd in Gappa: to nearest, tie breaking toward minus infinity *)
| Nu (* nu in Gappa: to nearest, tie breaking toward plus infinity *)

let pp_rounding_mode ppf m =
Format.pp_print_string ppf @@
match m with
| NearestTiesToEven -> "NearestTiesToEven"
| ToZero -> "ToZero"
| Up -> "Up"
| Down -> "Down"
| NearestTiesToAway -> "NearestTiesToAway"
| Aw -> "Aw"
| Od -> "Od"
| No -> "No"
| Nz -> "Nz"
| Nd -> "Nd"
| Nu -> "Nu"

let fpa_rounding_mode, rounding_mode_of_hs =
let cstrs =
[ (* standards *)
NearestTiesToEven;
ToZero;
Up;
Down;
NearestTiesToAway;
(* non standards *)
Aw;
Od;
No;
Nz;
Nd;
Nu ]
in
let h_cstrs =
List.map (fun c -> Hs.make (Format.asprintf "%a" pp_rounding_mode c)) cstrs
in
let ty = Ty.Tsum (Hs.make "fpa_rounding_mode", h_cstrs) in
let table =
let table = Hashtbl.create 17 in
List.iter2 (Hashtbl.add table) h_cstrs cstrs;
table
in
ty, Hashtbl.find table
| ToZero
| Up
| Down
| NearestTiesToAway

let cstrs =
[
NearestTiesToEven;
ToZero;
Up;
Down;
NearestTiesToAway;
]

module type S = sig
val fpa_rounding_mode_type_name : string

val fpa_rounding_mode : Ty.t

val rounding_mode_of_hs : Hstring.t -> rounding_mode

val string_of_rounding_mode : rounding_mode -> string
end

module Make (I : sig
val name : string
val to_string : rounding_mode -> string
end) : S = struct

let fpa_rounding_mode_type_name = I.name

let string_of_rounding_mode = I.to_string

let fpa_rounding_mode, rounding_mode_of_hs =
let h_cstrs =
List.map (fun c -> Hs.make (I.to_string c)) cstrs
in
let ty = Ty.Tsum (Hs.make I.name, h_cstrs) in
let table =
let table = Hashtbl.create 5 in
List.iter2 (
fun key bnd ->
Hashtbl.add table key bnd
) h_cstrs cstrs;
table
in
ty,
(fun key -> match Hashtbl.find_opt table key with
| None ->
Fmt.failwith
"%a"
(fun fmt k ->
Format.pp_print_string fmt I.name;
Hashtbl.iter (fun key _ -> Format.fprintf fmt "%a --" Hstring.print key) table;
Hstring.print fmt k
)
key
| Some res -> res)
end

module AE : S =
Make (struct
let name = "fpa_rounding_mode"
let to_string =
function
| NearestTiesToEven -> "NearestTiesToEven"
| ToZero -> "ToZero"
| Up -> "Up"
| Down -> "Down"
| NearestTiesToAway -> "NearestTiesToAway"
end
)

module SMT2 : S =
Make (struct
let name = "RoundingMode"

let to_string =
function
| NearestTiesToEven -> "RNE"
| ToZero -> "RTZ"
| Up -> "RTP"
| Down -> "RTN"
| NearestTiesToAway -> "RNA"
end)

let fpa_rounding_utils () =
match Options.get_input_format () with
| None | Some Smtlib2 -> (module SMT2 : S)
| _ -> (module AE : S)

(** Helper functions **)

Expand Down Expand Up @@ -160,9 +197,6 @@ let round_big_int (mode : rounding_mode) y =
if Q.sign diff = 0 then z
else if Q.compare diff half < 0 then z else Z.add z (signed_one y)

| Aw | Od | No | Nz | Nd | Nu -> assert false


let to_mantissa_exp prec exp mode x =
let sign_x = Q.sign x in
assert ((sign_x = 0) == Q.equal x Q.zero);
Expand Down
33 changes: 18 additions & 15 deletions src/lib/structures/fpa_rounding.mli
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,28 @@
(**************************************************************************)

type rounding_mode =
(* five standard/why3 fpa rounding modes *)
| NearestTiesToEven
(*ne in Gappa: to nearest, tie breaking to even mantissas*)
| ToZero (* zr in Gappa: toward zero *)
| Up (* up in Gappa: toward plus infinity *)
| Down (* dn in Gappa: toward minus infinity *)
| NearestTiesToAway (* na : to nearest, tie breaking away from zero *)
| ToZero
| Up
| Down
| NearestTiesToAway

(* additional Gappa rounding modes *)
| Aw (* aw in Gappa: away from zero **)
| Od (* od in Gappa: to odd mantissas *)
| No (* no in Gappa: to nearest, tie breaking to odd mantissas *)
| Nz (* nz in Gappa: to nearest, tie breaking toward zero *)
| Nd (* nd in Gappa: to nearest, tie breaking toward minus infinity *)
| Nu (* nu in Gappa: to nearest, tie breaking toward plus infinity *)
module type S = sig
val fpa_rounding_mode_type_name : string

val fpa_rounding_mode : Ty.t
val fpa_rounding_mode : Ty.t

val rounding_mode_of_hs : Hstring.t -> rounding_mode
val rounding_mode_of_hs : Hstring.t -> rounding_mode

val string_of_rounding_mode : rounding_mode -> string
end

module AE : S
module SMT2 : S

(** Returns (module SMT2) if the input format is [None] or [Some Smtlib2],
otherwise returns [AE]. *)
val fpa_rounding_utils : unit -> (module S)

(** Integer part of binary logarithm for NON-ZERO POSITIVE number **)
val integer_log_2 : Numbers.Q.t -> int
Expand Down

0 comments on commit eb8bbbf

Please sign in to comment.