Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: exporting alt-ergo FPA built-in primitives #876

Merged
merged 16 commits into from
Oct 30, 2023
240 changes: 177 additions & 63 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ end
(** Builtins *)
type _ DStd.Builtin.t +=
| Float
| AERound of int * int
(** Equivalent of Float for the SMT2 format. *)
| Integer_round
| Abs_real
| Sqrt_real
Expand All @@ -193,6 +195,49 @@ type _ DStd.Builtin.t +=
(* Internal use for semantic triggers -- do not expose outside of theories *)
| Not_theory_constant | Is_theory_constant | Linear_dependency

let builtin_term t = Dl.Typer.T.builtin_term t

let builtin_ty t = Dl.Typer.T.builtin_ty t

let ty name ty =
Id.Map.add { name = DStd.Name.simple name; ns = Sort } @@
fun env s ->
builtin_ty @@
Dolmen_type.Base.app0 (module Dl.Typer.T) env s ty

let builtin_enum = function
| Ty.Tsum (name, cstrs) as ty_ ->
let ty_cst =
DStd.Expr.Id.mk ~builtin:B.Base
(DStd.Path.global (Hstring.view name))
DStd.Expr.{ arity = 0; alias = No_alias }
in
let cstrs =
List.map (fun c -> DStd.Path.global (Hstring.view c), []) cstrs
in
let _, cstrs = DStd.Expr.Term.define_adt ty_cst [] cstrs in
let dty = DT.apply ty_cst [] in
let add_cstrs map =
List.fold_left (fun map ((c : DE.term_cst), _) ->
let name = get_basename c.path in
Id.Map.add { name = DStd.Name.simple name; ns = Term } (fun env _ ->
builtin_term @@
Dolmen_type.Base.term_app_cst
(module Dl.Typer.T) env c) map)
map cstrs
in
Cache.store_ty (DE.Ty.Const.hash ty_cst) ty_;
dty,
cstrs,
fun map ->
map
|> ty (Hstring.view name) dty
|> add_cstrs
| _ -> assert false

let fpa_rounding_mode, rounding_modes, add_rounding_modes =
builtin_enum Fpa_rounding.fpa_rounding_mode

module Const = struct
open DE

Expand All @@ -207,6 +252,15 @@ module Const = struct
let name = "int2bv" in
Id.mk ~name ~builtin:(Int2BV n)
(DStd.Path.global name) Ty.(arrow [int] (bitv n)))

let smt_round =
with_cache (fun (n, m) ->
let name = "ae.round" in
Id.mk
~name
~builtin:(AERound (n, m))
(DStd.Path.global name)
Ty.(arrow [fpa_rounding_mode; real] real))
end

let bv2nat t =
Expand All @@ -220,6 +274,9 @@ let bv2nat t =
let int2bv n t =
DE.Term.apply_cst (Const.int2bv n) [] [t]

let smt_round n m rm t =
DE.Term.apply_cst (Const.smt_round (n, m)) [] [rm; t]

let bv_builtins env s =
let term_app1 f =
Dl.Typer.T.builtin_term @@
Expand All @@ -241,54 +298,49 @@ let bv_builtins env s =
end
| _ -> `Not_found

let fpa_builtins =
(** Takes a dolmen identifier [id] and injects it in Alt-Ergo's registered
identifiers.
It transforms "fpa_rounding_mode", the Alt-Ergo builtin type into the SMT2
rounding type "RoundingMode". Also injects each constructor into their SMT2
equivalent *)
let inject_identifier id =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation comment is useful but calling this something explicit like inject_ae_to_smt2 would be clearer than inject_identifier which sounds generic.

match id with
| Id.{name = Simple n; _} ->
begin
if String.equal n Fpa_rounding.fpa_rounding_mode_ae_type_name then
(* Injecting the type name as the SMT2 Type name. *)
let name =
Dolmen_std.Name.simple Fpa_rounding.fpa_rounding_mode_type_name
in
{id with name}
else
match Fpa_rounding.rounding_mode_of_ae_hs (Hstring.make n) with
| rm ->
let name =
Dolmen_std.Name.simple (Fpa_rounding.string_of_rounding_mode rm)
in
{id with name}
| exception (Failure _) ->
id
end
| id -> id

let ae_fpa_builtins =
let (->.) args ret = (args, ret) in
let builtin_term t = Dl.Typer.T.builtin_term t in
let builtin_ty t = Dl.Typer.T.builtin_ty t in
let dterm name f =
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env s ->
builtin_term @@
Dolmen_type.Base.term_app1 (module Dl.Typer.T) env s f
in
let ty name ty =
Id.Map.add { name = DStd.Name.simple name; ns = Sort } @@
fun env s ->
builtin_ty @@
Dolmen_type.Base.app0 (module Dl.Typer.T) env s ty
in
let builtin_enum = function
| Ty.Tsum (name, cstrs) as ty_ ->
let ty_cst =
DStd.Expr.Id.mk ~builtin:B.Base
(DStd.Path.global (Hstring.view name))
DStd.Expr.{ arity = 0; alias = No_alias }
in
let cstrs =
List.map (fun c -> DStd.Path.global (Hstring.view c), []) cstrs
in
let _, cstrs = DStd.Expr.Term.define_adt ty_cst [] cstrs in
let dty = DT.apply ty_cst [] in
let add_cstrs map =
List.fold_left (fun map ((c : DE.term_cst), _) ->
let name = get_basename c.path in
Id.Map.add { name = DStd.Name.simple name; ns = Term } (fun env _ ->
builtin_term @@
Dolmen_type.Base.term_app_cst
(module Dl.Typer.T) env c) map)
map cstrs
in
Cache.store_ty (DE.Ty.Const.hash ty_cst) ty_;
dty,
cstrs,
fun map ->
map
|> ty (Hstring.view name) dty
|> add_cstrs
| _ -> assert false
in
let fpa_rounding_mode, rounding_modes, add_rounding_modes =
builtin_enum Fpa_rounding.fpa_rounding_mode
let op ?(tyvars = []) name builtin (args, ret) =
let ty = DT.pi tyvars @@ DT.arrow args ret in
let cst = DE.Id.mk ~name ~builtin (DStd.Path.global name) ty in
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env _ ->
builtin_term @@
Dolmen_type.Base.term_app_cst
(module Dl.Typer.T) env cst
in
let float_cst =
let ty = DT.(arrow [int; int; fpa_rounding_mode; real] real) in
Expand All @@ -311,15 +363,6 @@ let fpa_builtins =
let float32d x = float32 (mode "NearestTiesToEven") x in
let float64 = float (DE.Term.int "53") (DE.Term.int "1074") in
let float64d x = float64 (mode "NearestTiesToEven") x in
let op ?(tyvars = []) name builtin (args, ret) =
let ty = DT.pi tyvars @@ DT.arrow args ret in
let cst = DE.Id.mk ~name ~builtin (DStd.Path.global name) ty in
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env _ ->
builtin_term @@
Dolmen_type.Base.term_app_cst
(module Dl.Typer.T) env cst
in
let partial1 name f =
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env s ->
Expand All @@ -335,7 +378,11 @@ let fpa_builtins =
let is_theory_constant =
let open DT in
let a = Var.mk "alpha" in
op ~tyvars:[a] "is_theory_constant" Is_theory_constant ([of_var a] ->. prop)
op
~tyvars:[a]
"is_theory_constant"
Is_theory_constant
([of_var a] ->. prop)
in
let fpa_builtins =
let open DT in
Expand Down Expand Up @@ -409,24 +456,61 @@ let fpa_builtins =
|> op "not_theory_constant" Not_theory_constant ([real] ->. prop)
|> is_theory_constant
|> op "linear_dependency" Linear_dependency ([real; real] ->. prop)

in
fun env s ->
begin match s with
| Dl.Typer.T.Id id ->
begin
try
Id.Map.find_exn id fpa_builtins env s
with Not_found -> `Not_found
end
| Builtin _ -> `Not_found
end
let search_id id =
try
Id.Map.find_exn id fpa_builtins env s
with Not_found -> `Not_found
in
match s with
| Dl.Typer.T.Id id ->
let new_id = inject_identifier id in
search_id new_id
| Builtin _ -> `Not_found

let smt_fpa_builtins =
let term_app env s f =
Dl.Typer.T.builtin_term @@
Dolmen_type.Base.term_app2 (module Dl.Typer.T) env s f
in
let other_builtins =
Id.Map.empty
|> add_rounding_modes
in
fun env s ->
match s with
| Dl.Typer.T.Id {
ns = Term ;
name = Indexed {
basename = "ae.round" ;
indexes = [ i; j ] } } ->
begin match
int_of_string i,
int_of_string j
with
| n, m -> term_app env s (smt_round n m)
| exception Failure _ -> `Not_found
end
| Dl.Typer.T.Id id -> begin
match Id.Map.find_exn id other_builtins env s with
| e -> e
| exception Not_found -> `Not_found
end
| _ -> `Not_found

(** Concatenation of builtins handlers. *)
let (++) bt1 bt2 =
fun a b ->
match bt1 a b with
| `Not_found -> bt2 a b
| res -> res

let builtins =
fun _st (lang : Typer.lang) ->
match lang with
| `Logic Alt_ergo -> fpa_builtins
| `Logic (Smtlib2 _) -> bv_builtins
| `Logic Alt_ergo -> ae_fpa_builtins
| `Logic (Smtlib2 _) -> bv_builtins ++ smt_fpa_builtins
| _ -> fun _ _ -> `Not_found

(** Translates dolmen locs to Alt-Ergo's locs *)
Expand Down Expand Up @@ -929,6 +1013,13 @@ let mk_add translate sy ty l =
let args = aux_mk_add l in
E.mk_term sy args ty

let mk_rounding fpar =
let name = Fpa_rounding.string_of_rounding_mode fpar in
let ty = Fpa_rounding.fpa_rounding_mode in
let sy =
Sy.Op (Sy.Constr (Hstring.make name)) in
E.mk_term sy [] ty

(** [mk_expr ~loc ~name_base ~toplevel ~decl_kind term]

Builds an Alt-Ergo hashconsed expression from a dolmen term
Expand Down Expand Up @@ -1355,6 +1446,15 @@ let rec mk_expr
| _ -> unsupported "coercion: %a" DE.Term.print term
end
| Float, _ -> op Float
| AERound(i, j), _ ->
let args =
let i = E.Ints.of_int i in
let j = E.Ints.of_int j in
i :: j :: List.map (fun a -> aux_mk_expr a) args in
E.mk_term
(Sy.Op Float)
args
(dty_to_ty term_ty)
| Integer_round, _ -> op Integer_round
| Abs_real, _ -> op Abs_real
| Sqrt_real, _ -> op Sqrt_real
Expand All @@ -1373,6 +1473,20 @@ let rec mk_expr
| Not_theory_constant, _ -> op Not_theory_constant
| Is_theory_constant, _ -> op Is_theory_constant
| Linear_dependency, _ -> op Linear_dependency
| (B.RoundNearestTiesToEven
| B.RoundNearestTiesToAway
| B.RoundTowardPositive
| B.RoundTowardNegative
| B.RoundTowardZero as b), _ ->
let fpa_rounding = match b with
B.RoundNearestTiesToEven -> Fpa_rounding.NearestTiesToEven
| B.RoundNearestTiesToAway -> NearestTiesToAway
| B.RoundTowardPositive -> Up
| B.RoundTowardNegative -> Down
| B.RoundTowardZero -> ToZero
| _ -> assert false
in
mk_rounding fpa_rounding
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think with mk_rounding now the following is shorter and clearer:

Suggested change
| (B.RoundNearestTiesToEven
| B.RoundNearestTiesToAway
| B.RoundTowardPositive
| B.RoundTowardNegative
| B.RoundTowardZero as b), _ ->
let fpa_rounding = match b with
B.RoundNearestTiesToEven -> Fpa_rounding.NearestTiesToEven
| B.RoundNearestTiesToAway -> NearestTiesToAway
| B.RoundTowardPositive -> Up
| B.RoundTowardNegative -> Down
| B.RoundTowardZero -> ToZero
| _ -> assert false
in
mk_rounding fpa_rounding
| B.RoundNearestTiesToEven, _ -> mk_rounding NearestTiesToEven
| B.RoundNearestTiesToAway, _ -> mk_rounding NearestTiesToAway
| B.RoundTowardPositive, _ -> mk_rounding Up
| B.RoundTowardNegative, _ -> mk_rounding Down
| B.RoundTowardZero, _ -> mk_rounding ToZero

| _, _ -> unsupported "Application Term %a" DE.Term.print term
end

Expand Down
Loading
Loading