Skip to content

Commit

Permalink
feat: exporting alt-ergo FPA built-in primitives (#876)
Browse files Browse the repository at this point in the history
* Exporing rounding mode to SMT

* Adding round as a non indexed primitive

* Indexed identifier

* Poetry

* Adding some tests

* Reverting Rounding Mode as index

* Not relying on input format

* Injecting AE type float rounding type into SMT rounding type

* Poetry

* Style

* Poetry

* More poetry

* Also translating on the native side

* Adding missing tests

* Rebase artifact

* Adding tests and some poetry
  • Loading branch information
Stevendeo authored Oct 30, 2023
1 parent 42ed410 commit 0e3fc36
Show file tree
Hide file tree
Showing 14 changed files with 434 additions and 148 deletions.
231 changes: 168 additions & 63 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ end
(** Builtins *)
type _ DStd.Builtin.t +=
| Float
| AERound of int * int
(** Equivalent of Float for the SMT2 format. *)
| Integer_round
| Abs_real
| Sqrt_real
Expand All @@ -193,6 +195,49 @@ type _ DStd.Builtin.t +=
(* Internal use for semantic triggers -- do not expose outside of theories *)
| Not_theory_constant | Is_theory_constant | Linear_dependency

let builtin_term t = Dl.Typer.T.builtin_term t

let builtin_ty t = Dl.Typer.T.builtin_ty t

let ty name ty =
Id.Map.add { name = DStd.Name.simple name; ns = Sort } @@
fun env s ->
builtin_ty @@
Dolmen_type.Base.app0 (module Dl.Typer.T) env s ty

let builtin_enum = function
| Ty.Tsum (name, cstrs) as ty_ ->
let ty_cst =
DStd.Expr.Id.mk ~builtin:B.Base
(DStd.Path.global (Hstring.view name))
DStd.Expr.{ arity = 0; alias = No_alias }
in
let cstrs =
List.map (fun c -> DStd.Path.global (Hstring.view c), []) cstrs
in
let _, cstrs = DStd.Expr.Term.define_adt ty_cst [] cstrs in
let dty = DT.apply ty_cst [] in
let add_cstrs map =
List.fold_left (fun map ((c : DE.term_cst), _) ->
let name = get_basename c.path in
Id.Map.add { name = DStd.Name.simple name; ns = Term } (fun env _ ->
builtin_term @@
Dolmen_type.Base.term_app_cst
(module Dl.Typer.T) env c) map)
map cstrs
in
Cache.store_ty (DE.Ty.Const.hash ty_cst) ty_;
dty,
cstrs,
fun map ->
map
|> ty (Hstring.view name) dty
|> add_cstrs
| _ -> assert false

let fpa_rounding_mode, rounding_modes, add_rounding_modes =
builtin_enum Fpa_rounding.fpa_rounding_mode

module Const = struct
open DE

Expand All @@ -207,6 +252,15 @@ module Const = struct
let name = "int2bv" in
Id.mk ~name ~builtin:(Int2BV n)
(DStd.Path.global name) Ty.(arrow [int] (bitv n)))

let smt_round =
with_cache (fun (n, m) ->
let name = "ae.round" in
Id.mk
~name
~builtin:(AERound (n, m))
(DStd.Path.global name)
Ty.(arrow [fpa_rounding_mode; real] real))
end

let bv2nat t =
Expand All @@ -220,6 +274,9 @@ let bv2nat t =
let int2bv n t =
DE.Term.apply_cst (Const.int2bv n) [] [t]

let smt_round n m rm t =
DE.Term.apply_cst (Const.smt_round (n, m)) [] [rm; t]

let bv_builtins env s =
let term_app1 f =
Dl.Typer.T.builtin_term @@
Expand All @@ -241,54 +298,49 @@ let bv_builtins env s =
end
| _ -> `Not_found

let fpa_builtins =
(** Takes a dolmen identifier [id] and injects it in Alt-Ergo's registered
identifiers.
It transforms "fpa_rounding_mode", the Alt-Ergo builtin type into the SMT2
rounding type "RoundingMode". Also injects each constructor into their SMT2
equivalent *)
let inject_ae_to_smt2 id =
match id with
| Id.{name = Simple n; _} ->
begin
if String.equal n Fpa_rounding.fpa_rounding_mode_ae_type_name then
(* Injecting the type name as the SMT2 Type name. *)
let name =
Dolmen_std.Name.simple Fpa_rounding.fpa_rounding_mode_type_name
in
{id with name}
else
match Fpa_rounding.rounding_mode_of_ae_hs (Hstring.make n) with
| rm ->
let name =
Dolmen_std.Name.simple (Fpa_rounding.string_of_rounding_mode rm)
in
{id with name}
| exception (Failure _) ->
id
end
| id -> id

let ae_fpa_builtins =
let (->.) args ret = (args, ret) in
let builtin_term t = Dl.Typer.T.builtin_term t in
let builtin_ty t = Dl.Typer.T.builtin_ty t in
let dterm name f =
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env s ->
builtin_term @@
Dolmen_type.Base.term_app1 (module Dl.Typer.T) env s f
in
let ty name ty =
Id.Map.add { name = DStd.Name.simple name; ns = Sort } @@
fun env s ->
builtin_ty @@
Dolmen_type.Base.app0 (module Dl.Typer.T) env s ty
in
let builtin_enum = function
| Ty.Tsum (name, cstrs) as ty_ ->
let ty_cst =
DStd.Expr.Id.mk ~builtin:B.Base
(DStd.Path.global (Hstring.view name))
DStd.Expr.{ arity = 0; alias = No_alias }
in
let cstrs =
List.map (fun c -> DStd.Path.global (Hstring.view c), []) cstrs
in
let _, cstrs = DStd.Expr.Term.define_adt ty_cst [] cstrs in
let dty = DT.apply ty_cst [] in
let add_cstrs map =
List.fold_left (fun map ((c : DE.term_cst), _) ->
let name = get_basename c.path in
Id.Map.add { name = DStd.Name.simple name; ns = Term } (fun env _ ->
builtin_term @@
Dolmen_type.Base.term_app_cst
(module Dl.Typer.T) env c) map)
map cstrs
in
Cache.store_ty (DE.Ty.Const.hash ty_cst) ty_;
dty,
cstrs,
fun map ->
map
|> ty (Hstring.view name) dty
|> add_cstrs
| _ -> assert false
in
let fpa_rounding_mode, rounding_modes, add_rounding_modes =
builtin_enum Fpa_rounding.fpa_rounding_mode
let op ?(tyvars = []) name builtin (args, ret) =
let ty = DT.pi tyvars @@ DT.arrow args ret in
let cst = DE.Id.mk ~name ~builtin (DStd.Path.global name) ty in
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env _ ->
builtin_term @@
Dolmen_type.Base.term_app_cst
(module Dl.Typer.T) env cst
in
let float_cst =
let ty = DT.(arrow [int; int; fpa_rounding_mode; real] real) in
Expand All @@ -311,15 +363,6 @@ let fpa_builtins =
let float32d x = float32 (mode "NearestTiesToEven") x in
let float64 = float (DE.Term.int "53") (DE.Term.int "1074") in
let float64d x = float64 (mode "NearestTiesToEven") x in
let op ?(tyvars = []) name builtin (args, ret) =
let ty = DT.pi tyvars @@ DT.arrow args ret in
let cst = DE.Id.mk ~name ~builtin (DStd.Path.global name) ty in
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env _ ->
builtin_term @@
Dolmen_type.Base.term_app_cst
(module Dl.Typer.T) env cst
in
let partial1 name f =
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env s ->
Expand All @@ -335,7 +378,11 @@ let fpa_builtins =
let is_theory_constant =
let open DT in
let a = Var.mk "alpha" in
op ~tyvars:[a] "is_theory_constant" Is_theory_constant ([of_var a] ->. prop)
op
~tyvars:[a]
"is_theory_constant"
Is_theory_constant
([of_var a] ->. prop)
in
let fpa_builtins =
let open DT in
Expand Down Expand Up @@ -409,24 +456,61 @@ let fpa_builtins =
|> op "not_theory_constant" Not_theory_constant ([real] ->. prop)
|> is_theory_constant
|> op "linear_dependency" Linear_dependency ([real; real] ->. prop)

in
fun env s ->
begin match s with
| Dl.Typer.T.Id id ->
begin
try
Id.Map.find_exn id fpa_builtins env s
with Not_found -> `Not_found
end
| Builtin _ -> `Not_found
end
let search_id id =
try
Id.Map.find_exn id fpa_builtins env s
with Not_found -> `Not_found
in
match s with
| Dl.Typer.T.Id id ->
let new_id = inject_ae_to_smt2 id in
search_id new_id
| Builtin _ -> `Not_found

let smt_fpa_builtins =
let term_app env s f =
Dl.Typer.T.builtin_term @@
Dolmen_type.Base.term_app2 (module Dl.Typer.T) env s f
in
let other_builtins =
Id.Map.empty
|> add_rounding_modes
in
fun env s ->
match s with
| Dl.Typer.T.Id {
ns = Term ;
name = Indexed {
basename = "ae.round" ;
indexes = [ i; j ] } } ->
begin match
int_of_string i,
int_of_string j
with
| n, m -> term_app env s (smt_round n m)
| exception Failure _ -> `Not_found
end
| Dl.Typer.T.Id id -> begin
match Id.Map.find_exn id other_builtins env s with
| e -> e
| exception Not_found -> `Not_found
end
| _ -> `Not_found

(** Concatenation of builtins handlers. *)
let (++) bt1 bt2 =
fun a b ->
match bt1 a b with
| `Not_found -> bt2 a b
| res -> res

let builtins =
fun _st (lang : Typer.lang) ->
match lang with
| `Logic Alt_ergo -> fpa_builtins
| `Logic (Smtlib2 _) -> bv_builtins
| `Logic Alt_ergo -> ae_fpa_builtins
| `Logic (Smtlib2 _) -> bv_builtins ++ smt_fpa_builtins
| _ -> fun _ _ -> `Not_found

(** Translates dolmen locs to Alt-Ergo's locs *)
Expand Down Expand Up @@ -929,6 +1013,13 @@ let mk_add translate sy ty l =
let args = aux_mk_add l in
E.mk_term sy args ty

let mk_rounding fpar =
let name = Fpa_rounding.string_of_rounding_mode fpar in
let ty = Fpa_rounding.fpa_rounding_mode in
let sy =
Sy.Op (Sy.Constr (Hstring.make name)) in
E.mk_term sy [] ty

(** [mk_expr ~loc ~name_base ~toplevel ~decl_kind term]
Builds an Alt-Ergo hashconsed expression from a dolmen term
Expand Down Expand Up @@ -1355,6 +1446,15 @@ let rec mk_expr
| _ -> unsupported "coercion: %a" DE.Term.print term
end
| Float, _ -> op Float
| AERound(i, j), _ ->
let args =
let i = E.Ints.of_int i in
let j = E.Ints.of_int j in
i :: j :: List.map (fun a -> aux_mk_expr a) args in
E.mk_term
(Sy.Op Float)
args
(dty_to_ty term_ty)
| Integer_round, _ -> op Integer_round
| Abs_real, _ -> op Abs_real
| Sqrt_real, _ -> op Sqrt_real
Expand All @@ -1373,6 +1473,11 @@ let rec mk_expr
| Not_theory_constant, _ -> op Not_theory_constant
| Is_theory_constant, _ -> op Is_theory_constant
| Linear_dependency, _ -> op Linear_dependency
| B.RoundNearestTiesToEven, _ -> mk_rounding NearestTiesToEven
| B.RoundNearestTiesToAway, _ -> mk_rounding NearestTiesToAway
| B.RoundTowardPositive, _ -> mk_rounding Up
| B.RoundTowardNegative, _ -> mk_rounding Down
| B.RoundTowardZero, _ -> mk_rounding ToZero
| _, _ -> unsupported "Application Term %a" DE.Term.print term
end

Expand Down
45 changes: 29 additions & 16 deletions src/lib/frontend/typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -259,18 +259,29 @@ module Env = struct
| [ x; y ] -> f x y
| _ -> assert false

let add_builtin_enum = function
| Ty.Tsum (_, cstrs) as ty ->
let enum m h =
let s = Hstring.view h in
MString.add s (`Term (
Symbols.Op (Constr h),
{ args = []; result = ty },
Other
)) m
in
fun m -> List.fold_left enum m cstrs
| _ -> assert false
let add_fpa_enum map =
let ty = Fpa_rounding.fpa_rounding_mode in
match ty with
| Ty.Tsum (_, cstrs) ->
List.fold_left
(fun m c ->
match Fpa_rounding.translate_smt_rounding_mode c with
| None ->
(* The constructors of the type are expected to be AE rounding
modes. *)
assert false
| Some hs ->
MString.add (Hstring.view hs) (`Term (
Symbols.Op (Constr c),
{ args = []; result = ty },
Other
))
m
)
map
cstrs
| _ -> (* Fpa_rounding.fpa_rounding_mode is a sum type. *)
assert false

let find_builtin_cstr ty n =
match ty with
Expand Down Expand Up @@ -298,10 +309,12 @@ module Env = struct
let float prec exp mode x =
TTapp (Symbols.Op Float, [prec; exp; mode; x])
in
let nte = Fpa_rounding.string_of_rounding_mode NearestTiesToEven in
let tname = Fpa_rounding.fpa_rounding_mode_type_name in
let float32 = float (int "24") (int "149") in
let float32d = float32 (mode "NearestTiesToEven") in
let float32d = float32 (mode nte) in
let float64 = float (int "53") (int "1074") in
let float64d = float64 (mode "NearestTiesToEven") in
let float64d = float64 (mode nte) in
let op n op profile =
MString.add n @@ `Term (Symbols.Op op, profile, Other)
in
Expand All @@ -312,8 +325,8 @@ module Env = struct
let any = Ty.fresh_tvar in
let env = {
env with
types = Types.add_builtin env.types "fpa_rounding_mode" rm ;
builtins = add_builtin_enum Fpa_rounding.fpa_rounding_mode env.builtins;
types = Types.add_builtin env.types tname rm ;
builtins = add_fpa_enum env.builtins;
} in
let builtins =
env.builtins
Expand Down
2 changes: 1 addition & 1 deletion src/lib/structures/expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2820,7 +2820,7 @@ let const_view t =
end
| { f = Op (Constr c); ty; _ }
when Ty.equal ty Fpa_rounding.fpa_rounding_mode ->
RoundingMode (Fpa_rounding.rounding_mode_of_hs c)
RoundingMode (Fpa_rounding.rounding_mode_of_smt_hs c)
| _ -> Fmt.failwith "unsupported constant: %a" print t

let int_view t =
Expand Down
Loading

0 comments on commit 0e3fc36

Please sign in to comment.