Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(BV, CP): Add propagators for bvudiv and bvurem #1084

Merged
merged 3 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/lib/reasoners/bitv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ module Shostak(X : ALIEN) = struct
| Op (
Concat | Extract _ | BV2Nat
| BVnot | BVand | BVor | BVxor
| BVadd | BVsub | BVmul)
| BVadd | BVsub | BVmul | BVudiv | BVurem)
-> true
| _ -> false

Expand Down Expand Up @@ -408,7 +408,7 @@ module Shostak(X : ALIEN) = struct
match E.term_view t with
| { f = Op (
BVand | BVor | BVxor
| BVadd | BVsub | BVmul
| BVadd | BVsub | BVmul | BVudiv | BVurem
); _ } ->
X.term_embed t, []
| _ -> X.make t
Expand Down
50 changes: 49 additions & 1 deletion src/lib/reasoners/bitv_rel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,17 @@ module Constraint : sig
val bvmul : X.r -> X.r -> X.r -> t
(** [bvmul r x y] is the constraint [r = x * y] *)

val bvudiv : X.r -> X.r -> X.r -> t
(** [bvudir r x y] is the constraint [r = x / y]
This uses the convention that [x / 0] is [-1]. *)

val bvurem : X.r -> X.r -> X.r -> t
(** [bvurem r x y] is the constraint [r = x % y], where [x % y] is the
remainder of euclidean division.
This uses the convention that [x % 0] is [x]. *)

val bvule : X.r -> X.r -> t

val bvugt : X.r -> X.r -> t
Expand All @@ -261,14 +272,16 @@ end = struct
(* Bitwise operations *)
| Band | Bor | Bxor
(* Arithmetic operations *)
| Badd | Bmul
| Badd | Bmul | Budiv | Burem

let pp_binop ppf = function
| Band -> Fmt.pf ppf "bvand"
| Bor -> Fmt.pf ppf "bvor"
| Bxor -> Fmt.pf ppf "bvxor"
| Badd -> Fmt.pf ppf "bvadd"
| Bmul -> Fmt.pf ppf "bvmul"
| Budiv -> Fmt.pf ppf "bvudiv"
| Burem -> Fmt.pf ppf "bvurem"

let equal_binop op1 op2 =
match op1, op2 with
Expand All @@ -285,11 +298,18 @@ end = struct
| Badd, _ | _, Badd -> false

| Bmul, Bmul -> true
| Bmul, _ | _, Bmul -> false

| Budiv, Budiv -> true
| Budiv, _ | _, Budiv -> false

| Burem, Burem -> true

let hash_binop : binop -> int = Hashtbl.hash

let is_commutative = function
| Band | Bor | Bxor | Badd | Bmul -> true
| Budiv | Burem -> false

let propagate_binop ~ex dx op dy dz =
let open Bitlist_domains.Ephemeral in
Expand Down Expand Up @@ -326,6 +346,12 @@ end = struct
| Bmul -> (* Only forward propagation for now *)
update ~ex dx (Bitlist.mul !!dy !!dz)

| Budiv -> (* No bitlist propagation for now *)
()

| Burem -> (* No bitlist propagation for now *)
()

let propagate_interval_binop ~ex sz dr op dx dy =
let open Interval_domains.Ephemeral in
let norm i = Intervals.Int.extract i ~ofs:0 ~len:sz in
Expand All @@ -338,6 +364,12 @@ end = struct
| Bmul -> (* Only forward propagation for now *)
update ~ex dr @@ norm @@ Intervals.Int.mul !!dx !!dy

| Budiv -> (* Only forward propagation for now *)
update ~ex dr @@ Intervals.Int.bvudiv ~size:sz !!dx !!dy

| Burem -> (* Only forward propagation for now *)
update ~ex dr @@ Intervals.Int.bvurem !!dx !!dy

| Band | Bor | Bxor ->
(* No interval propagation for bitwise operators yet *)
()
Expand Down Expand Up @@ -540,6 +572,8 @@ end = struct
(* r = x - y <-> x = r + y *)
bvadd x r y
let bvmul = cbinop Bmul
let bvudiv = cbinop Budiv
let bvurem = cbinop Burem

let crel r = hcons @@ Crel r

Expand Down Expand Up @@ -703,6 +737,16 @@ end = struct
| Bxor -> cast ty @@ Z.logxor x y
| Badd -> cast ty @@ Z.add x y
| Bmul -> cast ty @@ Z.mul x y
| Budiv ->
if Z.equal y Z.zero then
cast ty Z.minus_one
else
cast ty @@ Z.div x y
| Burem ->
if Z.equal y Z.zero then
cast ty x
else
cast ty @@ Z.rem x y

(* Constant simplification rules for binary operators.
Expand Down Expand Up @@ -747,6 +791,8 @@ end = struct
add_mul_const acts r x (value y)
| Bmul -> false

| Budiv | Burem -> false

(* Algebraic rewrite rules for binary operators.
Rules based on constant simplifications are in [rw_binop_const]. *)
Expand Down Expand Up @@ -816,6 +862,8 @@ let extract_binop =
| BVadd -> Some bvadd
| BVsub -> Some bvsub
| BVmul -> Some bvmul
| BVudiv -> Some bvudiv
| BVurem -> Some bvurem
| _ -> None

let extract_constraints bcs uf r t =
Expand Down
39 changes: 39 additions & 0 deletions src/lib/reasoners/intervals.ml
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,45 @@ module Int = struct

let lognot u =
trace1 "lognot" u @@ map_strict_dec ZEuclideanType.lognot u

let bvudiv ~size:sz u1 u2 =
(* [bvudiv] is euclidean division where division by zero is -1 (as an
integer of width [sz], so 2^sz - 1) *)
let mone = Z.extract Z.minus_one 0 sz in
ediv ~div0:(Interval.of_bounds (Closed mone) (Closed mone)) u1 u2

let bvurem u1 u2 =
(* In the following, [x] is the implicit variable associated with [u1] and
[y] the implicit variable associated with [u2]. *)
of_set_nonempty @@
map_to_set (fun i2 ->
if ZEuclideanType.equal i2.ub ZEuclideanType.zero then
Halbaroth marked this conversation as resolved.
Show resolved Hide resolved
(* [y] is always zero -> identity *)
Halbaroth marked this conversation as resolved.
Show resolved Hide resolved
map_to_set interval_set u1
else if ZEuclideanType.compare i2.ub ZEuclideanType.zero < 0 then
(* Safety check -- bvurem only makes sense if [u2] is nonnegative. *)
invalid_arg "bvurem"
else
map_to_set (fun i1 ->
if ZEuclideanType.compare i1.ub i2.lb < 0 then
(* x < y : bvurem is the identity *)
Halbaroth marked this conversation as resolved.
Show resolved Hide resolved
interval_set i1
else if (
ZEuclideanType.equal i2.lb ZEuclideanType.zero
) then
(* The range [0, i1.ub] is always valid; it is also the best we
can do if [y] can be [0]. *)
interval_set { i1 with lb = ZEuclideanType.zero }
else
(* y is non-zero; we have both [x % y < y] and [x % y <= x] so
take the min of these upper bounds. *)
let ub =
if ZEuclideanType.compare i1.ub i2.ub < 0 then i1.ub
else ZEuclideanType.pred i2.ub
in
interval_set { lb = ZEuclideanType.zero ; ub }
) u1
) u2
end

module Legacy = struct
Expand Down
14 changes: 14 additions & 0 deletions src/lib/reasoners/intervals.mli
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ module Int : sig
{b Note}: The interval [s] must be an integer interval, but is
allowed to be unbounded (in which case [extract s i j] returns the
full interval [[0, 2^(j - i + 1) - 1]]). *)

val bvudiv : size:int -> t -> t -> t
(** [bvudiv sz s t] computes an overapproximation of integer division for
bit-vectors of width [sz] as defined in the FixedSizeBitVectors SMT-LIB
theory, i.e. where [bvudiv n 0] is [2^sz - 1].
[s] and [t] must be within the [0, 2^sz - 1] range. *)

val bvurem : t -> t -> t
(** [bvurem sz s t] computes an overapproximation of integer remainder for
bit-vectors of width [sz] as defined in the FixedSizeBitVectors SMT-LIB
theory, i.e. where [bvurem n 0] is [n].
[s] and [t] must be within the [0, 2^sz - 1] range. *)
end

module Legacy : sig
Expand Down
12 changes: 2 additions & 10 deletions src/lib/structures/expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3143,16 +3143,8 @@ module BV = struct
let bvsub s t = mk_term (Op BVsub) [s; t] (type_info s)
let bvneg s = bvsub (of_bigint_like s Z.zero) s
let bvmul s t = mk_term (Op BVmul) [s; t] (type_info s)
let bvudiv s t =
let m = size2 s t in
ite (eq (bv2nat t) Ints.(~$0))
(bvones m)
(int2bv m Ints.(bv2nat s / bv2nat t))
let bvurem s t =
let m = size2 s t in
ite (eq (bv2nat t) Ints.(~$0))
s
(int2bv m Ints.(bv2nat s mod bv2nat t))
let bvudiv s t = mk_term (Op BVudiv) [s; t] (type_info s)
let bvurem s t = mk_term (Op BVurem) [s; t] (type_info s)
let bvsdiv s t =
let m = size2 s t in
let msb_s = extract (m - 1) (m - 1) s in
Expand Down
8 changes: 6 additions & 2 deletions src/lib/structures/symbols.ml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type operator =
| Concat
| Extract of int * int (* lower bound * upper bound *)
| BVnot | BVand | BVor | BVxor
| BVadd | BVsub | BVmul
| BVadd | BVsub | BVmul | BVudiv | BVurem
| Int2BV of int | BV2Nat
(* FP *)
| Float
Expand Down Expand Up @@ -192,7 +192,7 @@ let compare_operators op1 op2 =
| Sqrt_real_excess | Min_real | Min_int | Max_real | Max_int
| Integer_log2 | Pow | Integer_round
| BVnot | BVand | BVor | BVxor
| BVadd | BVsub | BVmul
| BVadd | BVsub | BVmul | BVudiv | BVurem
| Int2BV _ | BV2Nat
| Not_theory_constant | Is_theory_constant | Linear_dependency
| Constr _ | Destruct _ | Tite) -> assert false
Expand Down Expand Up @@ -354,6 +354,8 @@ module AEPrinter = struct
| BVadd -> Fmt.pf ppf "bvadd"
| BVsub -> Fmt.pf ppf "bvsub"
| BVmul -> Fmt.pf ppf "bvmul"
| BVudiv -> Fmt.pf ppf "bvudiv"
| BVurem -> Fmt.pf ppf "bvurem"

(* ArraysEx theory *)
| Get -> Fmt.pf ppf "get"
Expand Down Expand Up @@ -457,6 +459,8 @@ module SmtPrinter = struct
| BVadd -> Fmt.pf ppf "bvadd"
| BVsub -> Fmt.pf ppf "bvsub"
| BVmul -> Fmt.pf ppf "bvmul"
| BVudiv -> Fmt.pf ppf "bvudiv"
| BVurem -> Fmt.pf ppf "bvurem"

(* ArraysEx theory *)
| Get -> Fmt.pf ppf "select"
Expand Down
2 changes: 1 addition & 1 deletion src/lib/structures/symbols.mli
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type operator =
| Concat
| Extract of int * int (* lower bound * upper bound *)
| BVnot | BVand | BVor | BVxor
| BVadd | BVsub | BVmul
| BVadd | BVsub | BVmul | BVudiv | BVurem
| Int2BV of int | BV2Nat
(* FP *)
| Float
Expand Down
38 changes: 38 additions & 0 deletions tests/bitvec_tests.ml
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,18 @@ let test_bitlist_binop ~count sz zop bop =
(IntSet.map2 zop (of_bitlist s) (of_bitlist t))
(of_bitlist u))

let test_interval_binop ~count sz zop iop =
Test.make ~count
~print:Print.(
pair
(Fmt.to_to_string Intervals.Int.pp)
(Fmt.to_to_string Intervals.Int.pp))
Gen.(pair (intervals sz) (intervals sz))
(fun (s, t) ->
IntSet.subset
(IntSet.map2 zop (of_interval s) (of_interval t))
(of_interval (iop s t)))

let zmul sz a b =
Z.extract (Z.mul a b) 0 sz

Expand All @@ -263,3 +275,29 @@ let test_bitlist_mul sz =

let () =
Test.check_exn (test_bitlist_mul 3)

let zudiv sz a b =
if Z.equal b Z.zero then
Z.extract Z.minus_one 0 sz
else
Z.div a b

let test_interval_bvudiv sz =
test_interval_binop ~count:1_000
sz (zudiv sz) (Intervals.Int.bvudiv ~size:sz)

let () =
Test.check_exn (test_interval_bvudiv 3)

let zurem a b =
if Z.equal b Z.zero then
a
else
Z.rem a b

let test_interval_bvurem sz =
test_interval_binop ~count:1_000
sz zurem Intervals.Int.bvurem

let () =
Test.check_exn (test_interval_bvurem 3)
Loading