Skip to content

Commit

Permalink
Add binder expression
Browse files Browse the repository at this point in the history
  • Loading branch information
filipeom committed Sep 13, 2024
1 parent 77e21fd commit 6c5e04e
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 11 deletions.
36 changes: 34 additions & 2 deletions src/ast/expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@

open Ty

type binder =
| Forall
| Exists
| Let_in

type t = expr Hc.hash_consed

and expr =
Expand All @@ -37,6 +42,12 @@ and expr =
| Naryop of Ty.t * naryop * t list
| Extract of t * int * int
| Concat of t * t
| Binder of binder * t list * t

let equal_binder a b =
match (a, b) with
| Forall, Forall | Exists, Exists | Let_in, Let_in -> true
| (Forall | Exists | Let_in), _ -> false

module Expr = struct
type t = expr
Expand Down Expand Up @@ -71,8 +82,10 @@ module Expr = struct
| Extract (e1, h1, l1), Extract (e2, h2, l2) ->
phys_equal e1 e2 && h1 = h2 && l1 = l2
| Concat (e1, e3), Concat (e2, e4) -> phys_equal e1 e2 && phys_equal e3 e4
| Binder (binder1, vars1, e1), Binder (binder2, vars2, e2) ->
equal_binder binder1 binder2 && list_eq vars1 vars2 && phys_equal e1 e2
| ( ( Val _ | Ptr _ | Symbol _ | List _ | App _ | Unop _ | Binop _ | Triop _
| Relop _ | Cvtop _ | Naryop _ | Extract _ | Concat _ )
| Relop _ | Cvtop _ | Naryop _ | Extract _ | Concat _ | Binder _ )
, _ ) ->
false

Expand All @@ -92,6 +105,7 @@ module Expr = struct
| Naryop (ty, op, es) -> h (ty, op, es)
| Extract (e, hi, lo) -> h (e.tag, hi, lo)
| Concat (e1, e2) -> h (e1.tag, e2.tag)
| Binder (b, vars, e) -> h (b, vars, e.tag)
end

module Hc = Hc.Make [@inlined hint] (Expr)
Expand Down Expand Up @@ -139,6 +153,7 @@ let rec ty (hte : t) : Ty.t =
| Ty_bitv n1, Ty_bitv n2 -> Ty_bitv (n1 + n2)
| t1, t2 ->
Fmt.failwith "Invalid concat of (%a) with (%a)" Ty.pp t1 Ty.pp t2 )
| Binder (_, _, e) -> ty e

let rec is_symbolic (v : t) : bool =
match view v with
Expand All @@ -156,6 +171,7 @@ let rec is_symbolic (v : t) : bool =
| Naryop (_, _, vs) -> List.exists is_symbolic vs
| Extract (e, _, _) -> is_symbolic e
| Concat (e1, e2) -> is_symbolic e1 || is_symbolic e2
| Binder (_, _, e) -> is_symbolic e

let get_symbols (hte : t list) =
let tbl = Hashtbl.create 64 in
Expand Down Expand Up @@ -183,6 +199,9 @@ let get_symbols (hte : t list) =
| Concat (e1, e2) ->
symbols e1;
symbols e2
| Binder (_, vars, e) ->
List.iter symbols vars;
symbols e
in
List.iter symbols hte;
Hashtbl.fold (fun k () acc -> k :: acc) tbl []
Expand All @@ -205,11 +224,16 @@ let negate_relop (hte : t) : (t, string) Result.t =
Result.map make e

module Pp = struct
let pp_binder fmt = function
| Forall -> Fmt.string fmt "forall"
| Exists -> Fmt.string fmt "exists"
| Let_in -> Fmt.string fmt "let"

let rec pp fmt (hte : t) =
match view hte with
| Val v -> Value.pp fmt v
| Ptr { base; offset } -> Fmt.pf fmt "(Ptr (i32 %ld) %a)" base pp offset
| Symbol s -> Symbol.pp fmt s
| Symbol s -> Fmt.pf fmt "@[<hov 1>%a@]" Symbol.pp s
| List v -> Fmt.pf fmt "@[<hov 1>[%a]@]" (Fmt.list ~sep:Fmt.comma pp) v
| App (s, v) ->
Fmt.pf fmt "@[<hov 1>(%a@ %a)@]" Symbol.pp s
Expand All @@ -233,6 +257,9 @@ module Pp = struct
| Extract (e, h, l) ->
Fmt.pf fmt "@[<hov 1>(extract@ %a@ %d@ %d)@]" pp e l h
| Concat (e1, e2) -> Fmt.pf fmt "@[<hov 1>(++@ %a@ %a)@]" pp e1 pp e2
| Binder (b, vars, e) ->
Fmt.pf fmt "@[<hov 1>(%a@ (%a)@ %a)@]" pp_binder b
(Fmt.list ~sep:Fmt.sp pp) vars pp e

let pp_list fmt (es : t list) = Fmt.hovbox (Fmt.list ~sep:Fmt.comma pp) fmt es

Expand Down Expand Up @@ -269,6 +296,8 @@ let ptr base offset = make (Ptr { base; offset })

let app symbol args = make (App (symbol, args))

let let_in vars expr = make (Binder (Let_in, vars, expr))

let unop' (ty : Ty.t) (op : unop) (hte : t) : t = make (Unop (ty, op, hte))
[@@inline]

Expand Down Expand Up @@ -516,6 +545,9 @@ let rec simplify_expr ?(rm_extract = true) (hte : t) : t =
let msb = simplify_expr ~rm_extract:false e1 in
let lsb = simplify_expr ~rm_extract:false e2 in
concat msb lsb
| Binder _ ->
(* Not simplifying anything atm *)
hte

let simplify (hte : t) : t =
let rec loop x =
Expand Down
11 changes: 10 additions & 1 deletion src/ast/expr.mli
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
(* along with this program. If not, see <https://www.gnu.org/licenses/>. *)
(***************************************************************************)

(* Dolmen's binders *)
type binder =
| Forall
| Exists
| Let_in

(** Term definitions of the abstract syntax *)
type t = expr Hc.hash_consed

Expand All @@ -36,6 +42,7 @@ and expr =
| Naryop of Ty.t * Ty.naryop * t list
| Extract of t * int * int
| Concat of t * t
| Binder of binder * t list * t

val equal : t -> t -> bool

Expand Down Expand Up @@ -71,7 +78,9 @@ val ptr : int32 -> t -> t

val symbol : Symbol.t -> t

val app: Symbol.t -> t list -> t
val app : Symbol.t -> t list -> t

val let_in : t list -> t -> t

(** Smart unop constructor, applies simplifications at constructor level *)
val unop : Ty.t -> Ty.unop -> t -> t
Expand Down
26 changes: 21 additions & 5 deletions src/ast/rewrite.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
{e declare-const}.
2. Propagate the correct theory encoding for [Unop], [Binop], [Relop], and
[Triop]. *)
[Triop].
module Type_map = Map.Make (Symbol)
3. Inlines [Let_in] binders into a single big expr *)

module Symb_map = Map.Make (Symbol)

(* TODO: Add proper logs *)
let debug = false

let debug fmt k = if debug then k (Fmt.epr fmt)
Expand All @@ -23,13 +26,19 @@ let rewrite_ty unknown_ty tys =
| Ty.Ty_none, _ -> assert false
| ty, _ -> ty

(** Inlines [Let_in] bindings into a single expr *)
let rewrite_let_in _expr_map hte =
debug "rewrite_let_: %a@." (fun k -> k Expr.pp hte);
assert false

(** Propagates types in [type_map] or inlines [Let_in] binders *)
let rec rewrite_expr type_map hte =
debug "rewrite_expr: %a@." (fun k -> k Expr.pp hte);
match Expr.view hte with
| Val _ -> hte
| Ptr { base; offset } -> Expr.ptr base (rewrite_expr type_map offset)
| Symbol sym -> (
match Type_map.find sym type_map with
match Symb_map.find sym type_map with
| exception Not_found -> Fmt.failwith "Undefined symbol: %a" Symbol.pp sym
| ty -> Expr.symbol { sym with ty } )
| List htes -> Expr.make (List (List.map (rewrite_expr type_map) htes))
Expand Down Expand Up @@ -70,7 +79,14 @@ let rec rewrite_expr type_map hte =
let hte1 = rewrite_expr type_map hte1 in
let hte2 = rewrite_expr type_map hte2 in
Expr.make (Concat (hte1, hte2))
| Binder (Let_in, _, _) ->
(* First, we match on the outer let_in and rewrite it as a single expr *)
let hte = rewrite_let_in Symb_map.empty hte in
(* Then, we rewrite the types of the expr *)
rewrite_expr type_map hte
| Binder (_, _, _) -> assert false

(** Acccumulates types of symbols in [type_map] and calls rewrite_expr *)
let rewrite_cmd type_map cmd =
debug " rewrite_cmd: %a@." (fun k -> k Ast.pp cmd);
match cmd with
Expand All @@ -80,7 +96,7 @@ let rewrite_cmd type_map cmd =
| Check_sat htes ->
let htes = List.map (rewrite_expr type_map) htes in
(type_map, Check_sat htes)
| Declare_const { id; sort } as cmd -> (Type_map.add id sort.ty type_map, cmd)
| Declare_const { id; sort } as cmd -> (Symb_map.add id sort.ty type_map, cmd)
| Get_value htes ->
let htes = List.map (rewrite_expr type_map) htes in
(type_map, Get_value htes)
Expand All @@ -92,6 +108,6 @@ let rewrite script =
(fun (type_map, cmds) cmd ->
let type_map, new_cmd = rewrite_cmd type_map cmd in
(type_map, new_cmd :: cmds) )
(Type_map.empty, []) script
(Symb_map.empty, []) script
in
List.rev cmds
2 changes: 1 addition & 1 deletion src/colibri2_mappings.default.ml
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ module Fresh = struct
let e1' = aux e1
and e2' = aux e2 in
DTerm.Bitv.concat e1' e2'
| List _ | App _ -> assert false
| List _ | App _ | Binder _ -> assert false
(* | Quantifier (t, vars, body, patterns) -> (
let body' = aux body in
let encode_pattern (p : t list) =
Expand Down
2 changes: 1 addition & 1 deletion src/mappings.ml
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct
let ctx, e1 = encode_expr ctx e1 in
let ctx, e2 = encode_expr ctx e2 in
(ctx, M.Bitv.concat e1 e2)
| List _ | App _ -> assert false
| List _ | App _ | Binder _ -> assert false

let encode_exprs ctx (es : Expr.t list) : symbol_ctx * M.term list =
List.fold_left
Expand Down
2 changes: 1 addition & 1 deletion src/parser/smtlib.ml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ module Term = struct
(* Ids can only be symbols. Any other expr here is super wrong *)
assert false

let letand ?loc:_ = assert false
let letand ?loc:_ (vars : t list) (expr : t) : t = Expr.let_in vars expr

let forall ?loc:_ = assert false

Expand Down

0 comments on commit 6c5e04e

Please sign in to comment.