From d679d20900235a084d4adbd16ca0f84b247d84db Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Basile=20Cl=C3=A9ment?= <basile.clement@ocamlpro.com>
Date: Mon, 10 Jun 2024 21:40:12 +0200
Subject: [PATCH] feat(BV): Do not store width in Bitlist

This patch rewrites the Bitlist module to represent infinite-width
bit-vectors rather than fixed-width bit-vectors, making it able to
represent unbounded integers. This improves the symmetry with the
interval domain (which also applies to unbounded integers) and is
intended to simplify the implementation of BV-to-int and int-to-BV
conversions.

In order to avoid having to use negative numbers in bitlists, the
internal representation is changed from a pair of (bits equal to [1],
bits equal to [0]) masks to a pair (bits equal to [1], unknown bits)
masks. This change should also be good for memory consumption, as we no
longer keep two copies of the value around when all bits are known.
---
 src/lib/reasoners/bitlist.ml  | 314 +++++++++++++++++++---------------
 src/lib/reasoners/bitlist.mli |  81 +++++----
 src/lib/reasoners/bitv_rel.ml | 198 ++++++++++-----------
 tests/bitvec_tests.ml         |  44 ++++-
 4 files changed, 350 insertions(+), 287 deletions(-)

diff --git a/src/lib/reasoners/bitlist.ml b/src/lib/reasoners/bitlist.ml
index 6b9adaa483..7255aaa0ec 100644
--- a/src/lib/reasoners/bitlist.ml
+++ b/src/lib/reasoners/bitlist.ml
@@ -29,140 +29,152 @@ module Ex = Explanation
 
 exception Inconsistent of Ex.t
 
-(** A bitlist representing the known bits of a bit-vector of width [width].
+(** A bitlist representing the known bits of an infinite-width bit-vector.
+    Negative numbers are represented in two's complement.
 
     Active bits in [bits_set] are necessarily equal to [1].
-    Active bits in [bits_clr] are necessarily equal to [0].
+    Active bits in [bits_unk] are not known and may be either [0] or [1].
+    Bits that are active in neither [bits_set] nor [bits_unk] are equal to [0].
+
+    The sign is known (and equal to the sign of [bits_set]) if [bits_unk] is
+    positive, and the sign is unknown if [bits_unk] is negative.
 
     The explanation [ex] justifies that the bitlist applies. *)
-type t = { width: int ; bits_set : Z.t ; bits_clr : Z.t ; ex : Ex.t }
+type t = { bits_set : Z.t ; bits_unk : Z.t ; ex : Ex.t }
+
+let constant n ex =
+  { bits_set = n ; bits_unk = Z.zero ; ex }
 
-let unknown width ex =
-  { width ; bits_set = Z.zero ; bits_clr = Z.zero ; ex }
+let zero ex = constant Z.zero ex
 
-let empty =
-  { width = 0 ; bits_set = Z.zero ; bits_clr = Z.zero ; ex = Ex.empty }
+let empty = zero Ex.empty
 
-let width { width; _ } = width
+let unknown = { bits_set = Z.zero ; bits_unk = Z.minus_one ; ex = Ex.empty }
 
 let explanation { ex; _ } = ex
 
-let exact width value ex =
-  { width
-  ; bits_set = Z.extract value 0 width
-  ; bits_clr = Z.extract (Z.lognot value) 0 width
+let exact value ex =
+  { bits_set = value
+  ; bits_unk = Z.zero
   ; ex }
 
 let equal b1 b2 =
-  b1.width = b2.width &&
   Z.equal b1.bits_set b2.bits_set &&
-  Z.equal b1.bits_clr b2.bits_clr
+  Z.equal b1.bits_unk b2.bits_unk
 
-let ones b = { b with bits_clr = Z.zero }
+let ones b = { b with bits_unk = Z.(b.bits_unk lor ~!(b.bits_set)) }
 
-let zeroes b = { b with bits_set = Z.zero }
+let zeroes b =
+  { b with bits_set = Z.zero ; bits_unk = Z.logor b.bits_unk b.bits_set }
 
 let add_explanation ~ex b = { b with ex = Ex.union b.ex ex }
 
-let pp ppf { width; bits_set; bits_clr; ex } =
+let pp ppf { bits_set; bits_unk; ex } =
+  begin if Z.sign bits_unk < 0 then
+      (* Sign is not known *)
+      Fmt.pf ppf "(?)"
+    else if Z.sign bits_set < 0 then
+      Fmt.pf ppf "(1)"
+    else
+      Fmt.pf ppf "(0)"
+  end;
+  let width = max (Z.numbits bits_set) (Z.numbits bits_unk) in
   for i = width - 1 downto 0 do
     if Z.testbit bits_set i then
       Fmt.pf ppf "1"
-    else if Z.testbit bits_clr i then
-      Fmt.pf ppf "0"
-    else
+    else if Z.testbit bits_unk i then
       Fmt.pf ppf "?"
+    else
+      Fmt.pf ppf "0"
   done;
   if Options.(get_verbose () || get_unsat_core ()) then
     Fmt.pf ppf " %a" (Fmt.box Ex.print) ex
 
-let bitlist ~width ~bits_set ~bits_clr ex =
-  if not (Z.equal (Z.logand bits_set bits_clr) Z.zero) then
-    raise @@ Inconsistent ex;
-
-  { width; bits_set; bits_clr ; ex }
-
-let bits_known b = Z.logor b.bits_set b.bits_clr
-
-let num_unknown b = b.width - Z.popcount (bits_known b)
+let unknown_bits b = b.bits_unk
 
 let value b = b.bits_set
 
-let is_fully_known b =
-  Z.(equal (shift_right (bits_known b + ~$1) b.width) ~$1)
+let is_fully_known b = Z.equal b.bits_unk Z.zero
 
 let intersect b1 b2 =
-  let width = b1.width in
   let bits_set = Z.logor b1.bits_set b2.bits_set in
-  let bits_clr = Z.logor b1.bits_clr b2.bits_clr in
-  bitlist ~width ~bits_set ~bits_clr
-    (Ex.union b1.ex b2.ex)
+  let bits_unk = Z.logand b1.bits_unk b2.bits_unk in
+  (* If there is a bit that is known in both bitlists with different values,
+     the intersection is empty. *)
+  let distinct = Z.logxor b1.bits_set b2.bits_set in
+  let known = Z.lognot (Z.logor b1.bits_unk b2.bits_unk) in
+  if not (Z.equal (Z.logand distinct known) Z.zero) then
+    raise @@ Inconsistent (Ex.union b1.ex b2.ex);
+
+  { bits_set ; bits_unk ; ex = Ex.union b1.ex b2.ex }
+
+let extract b ofs len =
+  if len = 0 then empty
+  else
+    (* Always consistent *)
+    { bits_set = Z.extract b.bits_set ofs len
+    ; bits_unk = Z.extract b.bits_unk ofs len
+    ; ex = b.ex
+    }
 
-let concat b1 b2 =
-  let bits_set = Z.(logor (b1.bits_set lsl b2.width) b2.bits_set)
-  and bits_clr = Z.(logor (b1.bits_clr lsl b2.width) b2.bits_clr)
-  in
+let lognot b =
   (* Always consistent *)
-  { width = b1.width + b2.width
-  ; bits_set
-  ; bits_clr
-  ; ex = Ex.union b1.ex b2.ex
-  }
+  { b with bits_set = Z.(~!(b.bits_set lor b.bits_unk))}
 
-let ( @ ) = concat
+let ( ~! ) = lognot
 
-let extract b i j =
+let logor b1 b2 =
+  (* A bit is set in [b1 | b2] if it is set in either [b1] or [b2]. *)
+  let bits_set = Z.logor b1.bits_set b2.bits_set in
+  (* A bit is unknown in [b1 | b2] if it is unknown in either [b1] or [b2],
+     unless is set to [1] in either [b1] or [b2]. *)
+  let bits_unk =
+    Z.logand (Z.logor b1.bits_unk b2.bits_unk)
+      (Z.lognot bits_set)
+  in
   (* Always consistent *)
-  { width = j - i + 1
-  ; bits_set = Z.extract b.bits_set i (j - i + 1)
-  ; bits_clr = Z.extract b.bits_clr i (j - i + 1)
-  ; ex = b.ex
+  { bits_set
+  ; bits_unk
+  ; ex = Ex.union b1.ex b2.ex
   }
 
-let lognot b =
-  (* Always consistent *)
-  { b with bits_set = b.bits_clr; bits_clr = b.bits_set }
+let ( lor ) = logor
 
 let logand b1 b2 =
-  let width = b1.width in
   let bits_set = Z.logand b1.bits_set b2.bits_set in
-  let bits_clr = Z.logor b1.bits_clr b2.bits_clr in
+  (* A bit is unknown in [b1 & b2] if it is unknown in both [b1] and [b2], or if
+     it is equal to [1] in either side and unknown in the other. *)
+  let bits_unk =
+    Z.logor (Z.logand b1.bits_set b2.bits_unk) @@
+    Z.logor (Z.logand b1.bits_unk b2.bits_set) @@
+    Z.logand b1.bits_unk b2.bits_unk
+  in
   (* Always consistent *)
-  { width
-  ; bits_set
-  ; bits_clr
+  { bits_set
+  ; bits_unk
   ; ex = Ex.union b1.ex b2.ex
   }
 
-let logor b1 b2 =
-  let width = b1.width in
-  let bits_set = Z.logor b1.bits_set b2.bits_set in
-  let bits_clr = Z.logand b1.bits_clr b2.bits_clr in
-  (* Always consistent *)
-  { width
-  ; bits_set
-  ; bits_clr
-  ; ex = Ex.union b1.ex b2.ex
-  }
+let ( land ) = logand
 
 let logxor b1 b2 =
-  let width = b1.width in
+  (* A bit is unknown in [b1 ^ b2] if it is unknown in either [b1] or [b2]. *)
+  let bits_unk = Z.logor b1.bits_unk b2.bits_unk in
+  (* Need to mask because [Z.logxor] will compute a bogus value for unknown
+     bits. *)
   let bits_set =
-    Z.logor
-      (Z.logand b1.bits_set b2.bits_clr)
-      (Z.logand b1.bits_clr b2.bits_set)
-  and bits_clr =
-    Z.logor
-      (Z.logand b1.bits_set b2.bits_set)
-      (Z.logand b1.bits_clr b2.bits_clr)
+    Z.logand
+      (Z.logxor b1.bits_set b2.bits_set)
+      (Z.lognot bits_unk)
   in
   (* Always consistent *)
-  { width
-  ; bits_set
-  ; bits_clr
+  { bits_set
+  ; bits_unk
   ; ex = Ex.union b1.ex b2.ex
   }
 
+let ( lxor ) = logxor
+
 (* The logic for the [increase_lower_bound] function below is described in
    section 4.1 of
 
@@ -176,9 +188,12 @@ let logxor b1 b2 =
 (* [left_cl_can_set highest_cleared cleared_can_set] returns the
    least-significant bit that is:
    - More significant than [highest_cleared], strictly;
-   - Set in [cleared_can_set] *)
+   - Set in [cleared_can_set]
+
+   Raises [Not_found] if there are no such bit. *)
 let left_cl_can_set highest_cleared cleared_can_set =
   let can_set = Z.(cleared_can_set asr highest_cleared) in
+  if Z.equal can_set Z.zero then raise Not_found;
   highest_cleared + Z.trailing_zeros can_set
 
 let increase_lower_bound b lb =
@@ -188,7 +203,7 @@ let increase_lower_bound b lb =
      [cleared_bits] contains the bits that were set in [lb] and got cleared in
      [r]; conversely, [set_bits] contains the bits that were cleared in [lb] and
      got set in [r]. *)
-  let r = Z.logor b.bits_set (Z.logand lb (Z.lognot b.bits_clr)) in
+  let r = Z.logor b.bits_set (Z.logand lb b.bits_unk) in
   let cleared_bits = Z.logand lb (Z.lognot r) in
   let set_bits = Z.logand (Z.lognot lb) r in
 
@@ -227,10 +242,10 @@ let increase_lower_bound b lb =
        as when the most-significant changed bit was 0 and is now 1 (see [if]
        case above). *)
     let bit_to_clear = Z.numbits cleared_bits in
-    let cleared_can_set = Z.lognot @@ Z.logor r b.bits_clr in
+    let cleared_can_set =
+      Z.logand (Z.lognot r) (Z.logor b.bits_set b.bits_unk)
+    in
     let bit_to_set = left_cl_can_set bit_to_clear cleared_can_set in
-    if bit_to_set >= b.width then
-      raise Not_found;
     let r = Z.logor r Z.(~$1 lsl bit_to_set) in
     let mask  = Z.(minus_one lsl bit_to_set) in
     Z.logand r @@ Z.logor mask b.bits_set
@@ -238,70 +253,81 @@ let increase_lower_bound b lb =
 
 let decrease_upper_bound b ub =
   (* x <= ub <-> ~ub <= ~x *)
-  let sz = width b in
-  assert (Z.numbits ub <= sz);
-  let nub =
-    increase_lower_bound (lognot b) (Z.extract (Z.lognot ub) 0 sz)
-  in
-  Z.extract (Z.lognot nub) 0 sz
+  Z.lognot @@ increase_lower_bound (lognot b) (Z.lognot ub)
 
 let fold_domain f b acc =
-  if b.width <= 0 then
+  (* If [bits_unk] is negative, the domain is infinite. *)
+  if Z.sign b.bits_unk < 0 then
     invalid_arg "Bitlist.fold_domain";
+  let width = Z.numbits b.bits_unk in
   let rec fold_domain_aux ofs b acc =
-    if ofs >= b.width then (
+    if ofs >= width then (
       assert (is_fully_known b);
       f (value b) acc
-    ) else if Z.testbit b.bits_clr ofs || Z.testbit b.bits_set ofs then
+    ) else if not (Z.testbit b.bits_unk ofs) then
       fold_domain_aux (ofs + 1) b acc
     else
       let mask = Z.(one lsl ofs) in
+      let bits_unk = Z.logand b.bits_unk (Z.lognot mask) in
       let acc =
         fold_domain_aux
-          (ofs + 1) { b with bits_clr = Z.logor b.bits_clr mask } acc
+          (ofs + 1) { b with bits_unk } acc
       in
       fold_domain_aux
-        (ofs + 1) { b with bits_set = Z.logor b.bits_set mask } acc
+        (ofs + 1) { b with bits_unk; bits_set = Z.logor b.bits_set mask } acc
   in
   fold_domain_aux 0 b acc
 
+let shift_left a n =
+  { bits_set = Z.(a.bits_set lsl n)
+  ; bits_unk = Z.(a.bits_unk lsl n)
+  ; ex = a.ex
+  }
+
+let ( lsl ) = shift_left
+
+let shift_right a n =
+  { bits_set = Z.(a.bits_set asr n)
+  ; bits_unk = Z.(a.bits_unk asr n)
+  ; ex = a.ex
+  }
+
+let ( asr ) = shift_right
+
 (* simple propagator: only compute known low bits *)
 let mul a b =
-  let sz = width a in
-  assert (width b = sz);
-
   let ex = Ex.union (explanation a) (explanation b) in
 
   (* (a * 2^n) * (b * 2^m) = (a * b) * 2^(n + m) *)
-  let zeroes_a = Z.trailing_zeros @@ Z.lognot a.bits_clr in
-  let zeroes_b = Z.trailing_zeros @@ Z.lognot b.bits_clr in
-  if zeroes_a + zeroes_b >= sz then
-    exact sz Z.zero ex
+  let zeroes_a = Z.trailing_zeros @@ Z.logor a.bits_set a.bits_unk in
+  let zeroes_b = Z.trailing_zeros @@ Z.logor b.bits_set b.bits_unk in
+  if zeroes_a = max_int || zeroes_b = max_int then
+    zero ex
   else
-    let low_bits =
-      if zeroes_a + zeroes_b = 0 then empty
-      else exact (zeroes_a + zeroes_b) Z.zero ex
-    in
-    let a = extract a zeroes_a (zeroes_a + sz - width low_bits - 1) in
-    assert (width a + width low_bits = sz);
-    let b = extract b zeroes_b (zeroes_b + sz - width low_bits - 1) in
-    assert (width b + width low_bits = sz);
+    let a = a asr zeroes_a in
+    let b = b asr zeroes_b in
+    let zeroes = zeroes_a + zeroes_b in
     (* ((ah * 2^n) + al) * ((bh * 2^m) + bl) =
         al * bl  (mod 2^(min n m)) *)
-    let low_a_known = Z.trailing_zeros @@ Z.lognot @@ bits_known a in
-    let low_b_known = Z.trailing_zeros @@ Z.lognot @@ bits_known b in
+    let low_a_known = Z.trailing_zeros @@ a.bits_unk in
+    let low_b_known = Z.trailing_zeros @@ b.bits_unk in
     let low_known = min low_a_known low_b_known in
+    let mid_bits = exact Z.(value a * value b) ex in
     let mid_bits =
-      if low_known = 0 then empty
-      else exact
-          low_known
-          Z.(extract (value a) 0 low_known * extract (value b) 0 low_known)
-          ex
+      if low_known = max_int then mid_bits
+      else extract mid_bits 0 low_known
     in
-    concat (unknown (sz - width mid_bits - width low_bits) Ex.empty) @@
-    concat mid_bits low_bits
+    if low_known = max_int then
+      mid_bits lsl zeroes
+    else
+      let high_bits =
+        { bits_set = Z.zero
+        ; bits_unk = Z.minus_one
+        ; ex = Ex.empty }
+      in
+      ((high_bits lsl low_known) lor mid_bits) lsl zeroes
 
-let shl a b =
+let bvshl ~size:sz a b =
   (* If the minimum value for [b] is larger than the bitwidth, the result is
      zero.
 
@@ -312,19 +338,21 @@ let shl a b =
      NB: we would like to use the lower bound from the interval domain for [b]
      here. *)
   match Z.to_int (increase_lower_bound b Z.zero) with
-  | n when n < width a ->
-    let low_zeros = Z.trailing_zeros @@ Z.lognot @@ a.bits_clr in
-    if low_zeros + n >= width a then
-      exact (width a) Z.zero (Ex.union (explanation a) (explanation b))
+  | n when n < sz ->
+    let low_zeros = Z.trailing_zeros @@ Z.logor a.bits_set a.bits_unk in
+    if low_zeros + n >= sz then
+      extract (exact Z.zero (Ex.union (explanation a) (explanation b))) 0 sz
     else if low_zeros + n > 0 then
-      concat (unknown (width a - low_zeros - n) Ex.empty) @@
-      exact (low_zeros + n) Z.zero (Ex.union (explanation a) (explanation b))
+      ((extract unknown 0 (sz - low_zeros - n)) lsl (low_zeros + n)) lor
+      extract
+        (exact Z.zero (Ex.union (explanation a) (explanation b)))
+        0 (low_zeros + n)
     else
-      unknown (width a) Ex.empty
+      extract unknown 0 sz
   | _ | exception Z.Overflow ->
-    exact (width a) Z.zero (explanation b)
+    constant Z.zero (explanation b)
 
-let lshr a b =
+let bvlshr ~size:sz a b =
   (* If the minimum value for [b] is larger than the bitwidth, the result is
      zero.
 
@@ -335,24 +363,26 @@ let lshr a b =
      NB: we would like to use the lower bound from the interval domain for [b]
      here. *)
   match Z.to_int (increase_lower_bound b Z.zero) with
-  | n when n < width a ->
-    let sz = width a in
-    if Z.testbit a.bits_clr (sz - 1) then (* MSB is zero *)
-      let low_msb_zero = Z.numbits @@ Z.extract (Z.lognot a.bits_clr) 0 sz in
+  | n when n < sz ->
+    if not (Z.testbit a.bits_unk (sz - 1) || Z.testbit a.bits_set (sz - 1))
+    then (* MSB is zero *)
+      let low_msb_zero =
+        Z.numbits @@ Z.extract (Z.logor a.bits_set a.bits_unk) 0 sz
+      in
       let nb_msb_zeros = sz - low_msb_zero in
       assert (nb_msb_zeros > 0);
       let nb_zeros = nb_msb_zeros + n in
       if nb_zeros >= sz then
-        exact sz Z.zero (Ex.union (explanation a) (explanation b))
+        constant Z.zero (Ex.union (explanation a) (explanation b))
       else
-        concat
-          (exact nb_zeros Z.zero (Ex.union (explanation a) (explanation b)))
-          (unknown (sz - nb_zeros) Ex.empty)
+        (
+          constant Z.zero (Ex.union (explanation a) (explanation b))
+          lsl (sz - nb_zeros)
+        ) lor (extract unknown 0 (sz - nb_zeros))
     else if n > 0 then
-      concat
-        (exact n Z.zero (explanation b))
-        (unknown (sz - n) Ex.empty)
+      (constant Z.zero (explanation b) lsl (sz - n)) lor
+      extract unknown 0 (sz - n)
     else
-      unknown sz Ex.empty
+      extract unknown 0 sz
   | _ | exception Z.Overflow ->
-    exact (width a) Z.zero (explanation b)
+    constant Z.zero (explanation b)
diff --git a/src/lib/reasoners/bitlist.mli b/src/lib/reasoners/bitlist.mli
index 14a69c888a..5d7779dffc 100644
--- a/src/lib/reasoners/bitlist.mli
+++ b/src/lib/reasoners/bitlist.mli
@@ -28,7 +28,9 @@
 (** Bit-lists provide a domain on bit-vectors that represent the known bits
     sets to [1] and [0], respectively.
 
-    This module provides an implementation of bitlists and related operators.*)
+    This module provides an implementation of bitlists and related operators.
+    The bitlists provided by this module do not have a fixed width, and can
+    represent arbitrary-precision integers. *)
 
 type t
 (** The type of bitlists.
@@ -49,22 +51,15 @@ val pp : t Fmt.t
 exception Inconsistent of Explanation.t
 (** Exception raised when an inconsistency is detected. *)
 
-val unknown : int -> Explanation.t -> t
-(** [unknown w ex] returns an bitlist of width [w] with no known bits. *)
-
-val empty : t
-(** An empty bitlist of width 0 and no explanation. *)
-
-val width : t -> int
-(** Returns the width of the bitlist. *)
+val unknown : t
+(** [unknown] is a bitlist that repersents all integers. *)
 
 val explanation : t -> Explanation.t
 (** Returns the explanation associated with the bitlist. See the type-level
     documentation for details. *)
 
-val exact : int -> Z.t -> Explanation.t -> t
-(** [exact w v ex] returns a bitlist of width [w] that represents the [w]-bits
-    constant [v]. *)
+val exact : Z.t -> Explanation.t -> t
+(** [exact v ex] returns a bitlist that represents the constant [v]. *)
 
 val equal : t -> t -> bool
 (** [equal b1 b2] returns [true] if the bitlists [b1] and [b2] are equal, albeit
@@ -82,16 +77,15 @@ val add_explanation : ex:Explanation.t -> t -> t
 (** [add_explanation ~ex b] adds the explanation [ex] to the bitlist [b]. The
     returned bitlist has both the explanation of [b] and [ex] as explanation. *)
 
-val bits_known : t -> Z.t
-(** [bits_known b] returns the sets of bits known to be either [1] or [0] as a
-    bitmask. *)
+val unknown_bits : t -> Z.t
+(** [unknown_bits b] returns the set of unknown (or undetermined) bits in [b].
 
-val num_unknown : t -> int
-(** [num_unknown b] returns the number of bits unknown in [b]. *)
+    The value of [Z.logand (Z.lognot (unknown_bits b)) n] is the same for any
+    [n] in the set represented by the bitlist [b]. *)
 
 val is_fully_known : t -> bool
 (** [is_fully_known b] determines if there are unknown bits in [b] or not.
-    [is_fully_known b] is [true] iff [num_unknown b] is [0]. *)
+    [is_fully_known b] is [true] iff [unknown_bits b] is [Z.zero]. *)
 
 val value : t -> Z.t
 (** [value b] returns the value associated with the bitlist [b]. If the bitlist
@@ -110,33 +104,38 @@ val lognot : t -> t
 (** [lognot b] swaps the bits that are set and cleared. *)
 
 val logand : t -> t -> t
-(** Bitwise and. *)
+(** Bit-wise and. *)
 
 val logor : t -> t -> t
-(** Bitwise or. *)
+(** Bit-wise or. *)
 
 val logxor : t -> t -> t
-(** Bitwise xor. *)
+(** Bit-wise xor. *)
 
 val mul : t -> t -> t
 (** Multiplication. *)
 
-val shl : t -> t -> t
-(** Logical left shift. *)
+val bvshl : size:int -> t -> t -> t
+(** Logical left shift, truncated to the [size] least significant bits. *)
 
-val lshr : t -> t -> t
-(** Logical right shift. *)
+val bvlshr : size:int -> t -> t -> t
+(** Logical right shift, truncated to the [size] least significant bits. *)
 
-val concat : t -> t -> t
-(** Bit-vector concatenation. *)
+val shift_left : t -> int -> t
+(** Shifts to the left. Equivalent to a multiplication by a power of [2]. The
+    second argument must be nonnegative. *)
 
-val ( @ ) : t -> t -> t
-(** Alias for [concat]. *)
+val shift_right : t -> int -> t
+(** Shifts to the right. This is an arithmetic shift, equivalent to a division
+    by a power of [2] with rounding towards -oo. The second argument must be
+    nonnegative. *)
 
 val extract : t -> int -> int -> t
-(** [extract b i j] returns the bitlist from index [i] to index [j] inclusive.
+(** [extract b off len] returns a nonnegative bitlist corresponding to bits
+    [off] to [off + len - 1] of [b].
 
-    The resulting bitlist has length [j - i + 1]. *)
+    {b Note}: This uses the same arguments as [Z.extract], not the arguments
+    from the SMT-LIB's [extract] primitive. *)
 
 val increase_lower_bound : t -> Z.t -> Z.t
 (** [increase_lower_bound b lb] returns the smallest integer [lb' >= lb] that
@@ -150,6 +149,26 @@ val decrease_upper_bound : t -> Z.t -> Z.t
 
     @raise Not_found if no such integer exists. *)
 
+(** {2 Prefix and infix operators} *)
+
+val ( land ) : t -> t -> t
+(** Bit-wise logical and [logand]. *)
+
+val ( lor ) : t -> t -> t
+(** Bit-wise logical inclusive or [logor]. *)
+
+val ( lxor ) : t -> t -> t
+(** Bit-wise logical exclusive xor [logxor]. *)
+
+val ( ~! ) : t -> t
+(** Bit-wise logical negation [lognot]. *)
+
+val ( lsl ) : t -> int -> t
+(** Bit-wise shift to the left [shift_left]. *)
+
+val ( asr ) : t -> int -> t
+(** Bit-wise shift to the right [shift_right]. *)
+
 (**/**)
 
 (** [fold_finite_domain f i acc] accumulates [f] on all the elements of [i] (in
diff --git a/src/lib/reasoners/bitv_rel.ml b/src/lib/reasoners/bitv_rel.ml
index 1a90250d67..59f949ea63 100644
--- a/src/lib/reasoners/bitv_rel.ml
+++ b/src/lib/reasoners/bitv_rel.ml
@@ -75,6 +75,9 @@ let is_bv_ty = function
 
 let is_bv_r r = is_bv_ty @@ X.type_info r
 
+let bitwidth r =
+  match X.type_info r with Tbitv n -> n | _ -> assert false
+
 module Interval_domain = struct
   type t = Intervals.Int.t
 
@@ -112,11 +115,7 @@ module Interval_domain = struct
     Intervals.Int.of_bounds ?ex (Closed n) (Closed n)
 
   let fold_leaves f r int acc =
-    let width =
-      match X.type_info r with
-      | Tbitv n -> n
-      | _ -> assert false
-    in
+    let width = bitwidth r in
     let j, acc =
       List.fold_left (fun (j, acc) { Bitv.bv; sz } ->
           (* sz = j - i + 1 => i = j - sz + 1 *)
@@ -170,49 +169,53 @@ module Bitlist_domain : Rel_utils.Domain with type t = Bitlist.t = struct
 
   let filter_ty = is_bv_ty
 
-  let fold_signed f { Bitv.value; negated } bl acc =
-    let bl = if negated then lognot bl else bl in
+  let fold_signed sz f { Bitv.value; negated } bl acc =
+    let bl = if negated then extract (lognot bl) 0 sz else bl in
     f value bl acc
 
   let fold_leaves f r bl acc =
-    fst @@ List.fold_left (fun (acc, bl) { Bitv.bv; sz } ->
+    let sz = bitwidth r in
+    let (acc, _, _) = List.fold_left (fun (acc, bl, w) { Bitv.bv; sz } ->
         (* Extract the bitlist associated with the current component *)
-        let mid = width bl - sz in
-        let bl_tail =
-          if mid = 0 then empty else
-            extract bl 0 (mid - 1)
-        in
-        let bl = extract bl mid (width bl - 1) in
+        let mid = w - sz in
+        let bl_tail = extract bl 0 mid in
+        let bl = extract bl mid (w - mid) in
 
         match bv with
         | Bitv.Cte z ->
+          assert (Z.numbits z <= sz);
           (* Nothing to update, but still check for consistency! *)
-          ignore @@ intersect bl (exact sz z Ex.empty);
-          acc, bl_tail
-        | Other r -> fold_signed f r bl acc, bl_tail
+          ignore @@ intersect bl (exact z Ex.empty);
+          acc, bl_tail, mid
+        | Other r -> fold_signed sz f r bl acc, bl_tail, mid
         | Ext (r, r_size, i, j) ->
           (* r<i, j> = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *)
-          assert (i + r_size - j - 1 + width bl = r_size);
-          let hi = Bitlist.unknown (r_size - j - 1) Ex.empty in
-          let lo = Bitlist.unknown i Ex.empty in
-          fold_signed f r (hi @ bl @ lo) acc, bl_tail
-      ) (acc, bl) (Shostak.Bitv.embed r)
-
-  let map_signed f { Bitv.value; negated } =
+          assert (i + r_size - j - 1 + w - mid = r_size);
+          let hi = Bitlist.(extract unknown 0 (r_size - j - 1)) in
+          let lo = Bitlist.(extract unknown 0 i) in
+          let bl_hd = Bitlist.((hi lsl (j + 1)) lor (bl lsl i) lor lo) in
+          fold_signed r_size f r bl_hd acc,
+          bl_tail,
+          mid
+      ) (acc, bl, sz) (Shostak.Bitv.embed r)
+    in acc
+
+  let map_signed sz f { Bitv.value; negated } =
     let bl = f value in
-    if negated then lognot bl else bl
+    if negated then extract (lognot bl) 0 sz else bl
 
   let map_leaves f r =
     List.fold_left (fun bl { Bitv.bv; sz } ->
-        concat bl @@
+        bl lsl sz lor
         match bv with
-        | Bitv.Cte z -> exact sz z Ex.empty
-        | Other r -> map_signed f r
-        | Ext (r, _r_size, i, j) -> extract (map_signed f r) i j
-      ) empty (Shostak.Bitv.embed r)
+        | Bitv.Cte z -> extract (exact z Ex.empty) 0 sz
+        | Other r -> map_signed sz f r
+        | Ext (r, r_sz, i, j) ->
+          extract (map_signed r_sz f r) i (j - i + 1)
+      ) (exact Z.zero Ex.empty) (Shostak.Bitv.embed r)
 
   let unknown = function
-    | Ty.Tbitv n -> unknown n Ex.empty
+    | Ty.Tbitv n -> extract unknown 0 n
     | _ ->
       (* Only bit-vector values can have bitlist domains. *)
       invalid_arg "unknown"
@@ -327,46 +330,51 @@ end = struct
     | Band | Bor | Bxor | Badd | Bmul -> true
     | Budiv | Burem | Bshl | Blshr -> false
 
-  let propagate_binop ~ex dx op dy dz =
+  let propagate_binop ~ex sz dx op dy dz =
     let open Domains.Ephemeral in
+    let norm bl = Bitlist.extract bl 0 sz in
     match op with
     | Band ->
-      update ~ex dx (Bitlist.logand !!dy !!dz);
+      update ~ex dx @@ norm @@ Bitlist.logand !!dy !!dz;
       (* Reverse propagation for y: if [x = y & z] then:
          - Any [1] in [x] must be a [1] in [y]
          - Any [0] in [x] that is also a [1] in [z] must be a [0] in [y]
       *)
-      update ~ex dy (Bitlist.ones !!dx);
-      update ~ex dy Bitlist.(logor (zeroes !!dx) (lognot (ones !!dz)));
-      update ~ex dz (Bitlist.ones !!dx);
-      update ~ex dz Bitlist.(logor (zeroes !!dx) (lognot (ones !!dy)))
+      update ~ex dy @@ norm @@ Bitlist.ones !!dx;
+      update ~ex dy @@ norm @@
+      Bitlist.(logor (zeroes !!dx) (lognot (ones !!dz)));
+      update ~ex dz @@ norm @@ Bitlist.ones !!dx;
+      update ~ex dz @@ norm @@
+      Bitlist.(logor (zeroes !!dx) (lognot (ones !!dy)))
     | Bor ->
-      update ~ex dx (Bitlist.logor !!dy !!dz);
+      update ~ex dx @@ norm @@ Bitlist.logor !!dy !!dz;
       (* Reverse propagation for y: if [x = y | z] then:
          - Any [0] in [x] must be a [0] in [y]
          - Any [1] in [x] that is also a [0] in [z] must be a [1] in [y]
       *)
-      update ~ex dy (Bitlist.zeroes !!dx);
-      update ~ex dy Bitlist.(logand (ones !!dx) (lognot (zeroes !!dz)));
-      update ~ex dz (Bitlist.zeroes !!dx);
-      update ~ex dz Bitlist.(logand (ones !!dx) (lognot (zeroes !!dy)))
+      update ~ex dy @@ norm @@ Bitlist.zeroes !!dx;
+      update ~ex dy @@ norm @@
+      Bitlist.(logand (ones !!dx) (lognot (zeroes !!dz)));
+      update ~ex dz @@ norm @@ Bitlist.zeroes !!dx;
+      update ~ex dz @@ norm @@
+      Bitlist.(logand (ones !!dx) (lognot (zeroes !!dy)))
     | Bxor ->
-      update ~ex dx (Bitlist.logxor !!dy !!dz);
+      update ~ex dx @@ norm @@ Bitlist.logxor !!dy !!dz;
       (* x = y ^ z <-> y = x ^ z *)
-      update ~ex dy (Bitlist.logxor !!dx !!dz);
-      update ~ex dz (Bitlist.logxor !!dx !!dy)
+      update ~ex dy @@ norm @@ Bitlist.logxor !!dx !!dz;
+      update ~ex dz @@ norm @@ Bitlist.logxor !!dx !!dy
     | Badd ->
       (* TODO: full adder propagation *)
       ()
 
     | Bshl -> (* Only forward propagation for now *)
-      update ~ex dx (Bitlist.shl !!dy !!dz)
+      update ~ex dx (Bitlist.bvshl ~size:sz !!dy !!dz)
 
     | Blshr -> (* Only forward propagation for now *)
-      update ~ex dx (Bitlist.lshr !!dy !!dz)
+      update ~ex dx (Bitlist.bvlshr ~size:sz !!dy !!dz)
 
     | Bmul -> (* Only forward propagation for now *)
-      update ~ex dx (Bitlist.mul !!dy !!dz)
+      update ~ex dx @@ norm @@ Bitlist.mul !!dy !!dz
 
     | Budiv -> (* No bitlist propagation for now *)
       ()
@@ -434,13 +442,14 @@ end = struct
     let get r = handle dom r in
     match f with
     | Fbinop (op, x, y) ->
-      propagate_binop ~ex (get r) op (get x) (get y)
+      let n = bitwidth r in
+      propagate_binop ~ex n (get r) op (get x) (get y)
 
   let propagate_interval_fun_t ~ex dom r f =
     let get r = Interval_domains.Ephemeral.handle dom r in
     match f with
     | Fbinop (op, x, y) ->
-      let sz = match X.type_info r with Tbitv n -> n | _ -> assert false in
+      let sz = bitwidth r in
       propagate_interval_binop ~ex sz (get r) op (get x) (get y)
 
   type binrel = Rule | Rugt
@@ -629,9 +638,6 @@ end = struct
   let propagate_interval ~ex c dom =
     propagate_interval_repr ~ex dom c.repr
 
-  let bitwidth r =
-    match X.type_info r with Tbitv n -> n | _ -> assert false
-
   let const sz n =
     Shostak.Bitv.is_mine [ { bv = Cte (Z.extract n 0 sz); sz } ]
 
@@ -979,15 +985,15 @@ let rec mk_eq ex lhs w z =
     applies to [r], exposes the equality [r = bl] as a list of Xliteral values
     (accumulated into [acc]) so that the union-find learns about the equality *)
 let add_eqs =
-  let rec aux x x_sz acc bl =
-    let known = Bitlist.bits_known bl in
-    let width = Bitlist.width bl in
+  let rec aux x x_sz acc width bl =
+    let known = Z.lognot (Bitlist.unknown_bits bl) in
+    let known = Z.extract known 0 width in
     let nbits = Z.numbits known in
     assert (nbits <= width);
     if nbits = 0 then
       acc
     else if nbits < width then
-      aux x x_sz acc (Bitlist.extract bl 0 (nbits - 1))
+      aux x x_sz acc nbits (Bitlist.extract bl 0 nbits)
     else
       let nbits = Z.numbits (Z.extract (Z.lognot known) 0 width) in
       let v = Z.extract (Bitlist.value bl) nbits (width - nbits) in
@@ -997,10 +1003,10 @@ let add_eqs =
       if nbits = 0 then
         lits @ acc
       else
-        aux x x_sz (lits @ acc) (Bitlist.extract bl 0 (nbits - 1))
+        aux x x_sz (lits @ acc) nbits (Bitlist.extract bl 0 nbits)
   in
-  fun acc x bl ->
-    aux x (Bitlist.width bl) acc bl
+  fun acc x x_sz bl ->
+    aux x x_sz acc x_sz bl
 
 module Any_constraint = struct
   type t =
@@ -1030,32 +1036,6 @@ end
 
 module QC = Uqueue.Make(Any_constraint)
 
-(* Compute the number of most significant bits shared by [inf] and [sup].
-
-   Requires: [inf <= sup]
-   Ensures:
-    result is the greatest integer <= sz such that
-    inf<sz - result .. sz> = sup<sz - result .. sz>
-
-    In particular, [result = sz] iff [inf = sup] and [result = 0] iff the most
-    significant bits of [inf] and [sup] differ. *)
-let rec shared_msb sz inf sup =
-  let numbits_inf = Z.numbits inf in
-  let numbits_sup = Z.numbits sup in
-  assert (numbits_inf <= numbits_sup);
-  if numbits_inf = numbits_sup then
-    (* Top [sz - numbits_inf] bits are 0 in both; look at 1s *)
-    if numbits_inf = 0 then
-      sz
-    else
-      sz - numbits_inf +
-      shared_msb numbits_inf
-        (Z.extract (Z.lognot sup) 0 numbits_inf)
-        (Z.extract (Z.lognot inf) 0 numbits_inf)
-  else
-    (* Top [sz - numbits_sup] are 0 in both, the next significant bit differs *)
-    sz - numbits_sup
-
 let finite_lower_bound = function
   | Intervals_intf.Unbounded -> Z.zero
   | Closed n -> n
@@ -1082,23 +1062,24 @@ let finite_upper_bound ~size:sz = function
    For example, m = 48 and M = 52 (00110000 and 00110100 in binary) share their
    five most-significant bits, denoted [00110???]. Therefore, a bit-vector bl =
    [0??1???0] can be refined into [00110??0]. *)
-let constrain_bitlist_from_interval bv int =
+let constrain_bitlist_from_interval ~size:sz bv int =
   let open Domains.Ephemeral in
-  let sz = Bitlist.width !!bv in
 
   let inf, inf_ex = Intervals.Int.lower_bound int in
   let inf = finite_lower_bound inf in
   let sup, sup_ex = Intervals.Int.upper_bound int in
   let sup = finite_upper_bound ~size:sz sup in
 
-  let nshared = shared_msb sz inf sup in
-  if nshared > 0 then
-    let ex = Ex.union inf_ex sup_ex in
-    let shared_bl =
-      Bitlist.exact nshared (Z.extract inf (sz - nshared) nshared) ex
-    in
-    update ~ex bv @@
-    Bitlist.concat shared_bl (Bitlist.unknown (sz - nshared) Ex.empty)
+  let distinct = Z.logxor inf sup in
+  (* If [distinct] is negative, [inf] and [sup] have different signs and there
+     are no significant shared bits. *)
+  if Z.sign distinct >= 0 then
+    let ofs = Z.numbits distinct in
+    update ~ex:Ex.empty bv @@
+    Bitlist.(
+      exact Z.((inf asr ofs) lsl ofs) (Ex.union inf_ex sup_ex) lor
+      extract unknown 0 ofs
+    )
 
 (* Algorithm 1 from
 
@@ -1112,7 +1093,7 @@ let constrain_bitlist_from_interval bv int =
    This function is a wrapper calling [Bitlist.increase_lower_bound] and
    [Bitlist.decrease_upper_bound] on all the constituent intervals of an union;
    see the documentation of these functions for details. *)
-let constrain_interval_from_bitlist int bv =
+let constrain_interval_from_bitlist ~size:sz int bv =
   let open Interval_domains.Ephemeral in
   let ex = Bitlist.explanation bv in
   (* Handy wrapper around [of_complement] *)
@@ -1129,7 +1110,7 @@ let constrain_interval_from_bitlist int bv =
   Intervals.Int.fold (fun acc i ->
       let { Intervals_intf.lb ; ub } = Intervals.Int.Interval.view i in
       let lb = finite_lower_bound lb in
-      let ub = finite_upper_bound ~size:(Bitlist.width bv) ub in
+      let ub = finite_upper_bound ~size:sz ub in
       let acc =
         match Bitlist.increase_lower_bound bv lb with
         | new_lb when Z.compare new_lb lb > 0 ->
@@ -1227,7 +1208,8 @@ let propagate_all eqs bcs bdom idom =
     touch_pending queue;
     HX.iter (fun r () ->
         HX.replace bitlist_changed r ();
-        constrain_interval_from_bitlist
+        let sz = bitwidth r in
+        constrain_interval_from_bitlist ~size:sz
           Interval_domains.Ephemeral.(handle idom r)
           Domains.Ephemeral.(!!(handle bdom r))
       ) touched;
@@ -1241,7 +1223,8 @@ let propagate_all eqs bcs bdom idom =
     let bcs = Constraints.clear_pending bcs in
     while HX.length touched > 0 do
       HX.iter (fun r () ->
-          constrain_bitlist_from_interval
+          let sz = bitwidth r in
+          constrain_bitlist_from_interval ~size:sz
             Domains.Ephemeral.(handle bdom r)
             Interval_domains.Ephemeral.(!!(handle idom r))
         ) touched;
@@ -1250,8 +1233,9 @@ let propagate_all eqs bcs bdom idom =
       assert (QC.is_empty queue);
 
       HX.iter (fun r () ->
+          let sz = bitwidth r in
           HX.replace bitlist_changed r ();
-          constrain_interval_from_bitlist
+          constrain_interval_from_bitlist ~size:sz
             Interval_domains.Ephemeral.(handle idom r)
             Domains.Ephemeral.(!!(handle bdom r))
         ) touched;
@@ -1263,7 +1247,8 @@ let propagate_all eqs bcs bdom idom =
     let eqs =
       HX.fold (fun r () acc ->
           let d = Domains.Ephemeral.(!!(handle bdom r)) in
-          add_eqs acc (Shostak.Bitv.embed r) d
+          let sz = bitwidth r in
+          add_eqs acc (Shostak.Bitv.embed r) sz d
         ) bitlist_changed eqs
     in
 
@@ -1379,7 +1364,7 @@ let case_split env uf ~for_model =
 
        [nunk] is the number of unknown bits. *)
     let f_acc r bl acc =
-      let nunk = Bitlist.num_unknown bl in
+      let nunk = Z.popcount (Bitlist.unknown_bits bl) in
       if nunk = 0 then
         acc
       else
@@ -1411,8 +1396,8 @@ let case_split env uf ~for_model =
     match SX.choose candidates with
     | r ->
       let bl = Domains.get r domain in
-      let w = Bitlist.width bl in
-      let unknown = Z.extract (Z.lognot @@ Bitlist.bits_known bl) 0 w in
+      let w = bitwidth r in
+      let unknown = Z.extract (Bitlist.unknown_bits bl) 0 w in
       let bitidx = Z.numbits unknown  - 1 in
       let lhs =
         Shostak.Bitv.is_mine @@
@@ -1451,5 +1436,6 @@ let assume_th_elt t th_elt _ =
   | _ -> t
 
 module Test = struct
-  let shared_msb = shared_msb
+  let shared_msb sz inf sup =
+    sz - Z.numbits (Z.logxor inf sup)
 end
diff --git a/tests/bitvec_tests.ml b/tests/bitvec_tests.ml
index 723cd37008..b6f95a58d9 100644
--- a/tests/bitvec_tests.ml
+++ b/tests/bitvec_tests.ml
@@ -1,6 +1,30 @@
 open AltErgoLib
 open QCheck2
 
+module type FixedSizeBitVector = sig
+  type t = Bitlist.t
+
+  val shl : t -> t -> t
+
+  val lshr : t -> t -> t
+
+  val mul : t -> t -> t
+end
+
+let fixed_size_bit_vector n : (module FixedSizeBitVector) =
+  let open Bitlist in
+  let norm b = extract b 0 n in
+  let binop op x y = norm (op x y) in
+  (module struct
+    type t = Bitlist.t
+
+    let shl a b = bvshl ~size:n a b
+
+    let lshr a b = bvlshr ~size:n a b
+
+    let mul = binop mul
+  end)
+
 module IntSet : sig
   type t
 
@@ -134,10 +158,16 @@ let bitlist sz =
   in
   let* (set_bits, clr_bits) = bitlist sz in
   let set_bits =
-    Bitlist.ones @@ Bitlist.exact sz set_bits Explanation.empty
+    Bitlist.extract (
+      Bitlist.ones @@
+      Bitlist.exact set_bits Explanation.empty
+    ) 0 sz
   in
   let clr_bits =
-    Bitlist.zeroes @@ Bitlist.exact sz (Z.lognot clr_bits) Explanation.empty
+    Bitlist.extract (
+      Bitlist.zeroes @@
+      Bitlist.exact (Z.extract (Z.lognot clr_bits) 0 sz) Explanation.empty
+    ) 0 sz
   in
   return @@ Bitlist.intersect set_bits clr_bits
 
@@ -247,9 +277,7 @@ let test_bitlist_binop ~count sz zop bop =
           (Fmt.to_to_string Bitlist.pp))
     Gen.(pair (bitlist sz) (bitlist sz))
     (fun (s, t) ->
-       let u = bop s t in
-       Bitlist.width u = Bitlist.width s &&
-       Bitlist.width u = Bitlist.width t &&
+       let u = bop (fixed_size_bit_vector sz) s t in
        IntSet.subset
          (IntSet.map2 zop (of_bitlist s) (of_bitlist t))
          (of_bitlist u))
@@ -271,7 +299,7 @@ let zmul sz a b =
 
 let test_bitlist_mul sz =
   test_bitlist_binop ~count:1_000
-    sz (zmul sz) Bitlist.mul
+    sz (zmul sz) (fun (module BV) -> BV.mul)
 
 let () =
   Test.check_exn (test_bitlist_mul 3)
@@ -290,7 +318,7 @@ let () =
 
 let test_bitlist_shl sz =
   test_bitlist_binop ~count:1_000
-    sz (zshl sz) Bitlist.shl
+    sz (zshl sz) (fun (module BV) -> BV.shl)
 
 let () =
   Test.check_exn (test_bitlist_shl 3)
@@ -309,7 +337,7 @@ let () =
 
 let test_bitlist_lshr sz =
   test_bitlist_binop ~count:1_000
-    sz zlshr Bitlist.lshr
+    sz zlshr (fun (module BV) -> BV.lshr)
 
 let () =
   Test.check_exn (test_bitlist_lshr 3)