diff --git a/src/engine.ml b/src/engine.ml index 3735be20..496d4b01 100644 --- a/src/engine.ml +++ b/src/engine.ml @@ -1031,11 +1031,11 @@ let pp_error ppf = function actual | `Msg msg -> Fmt.string ppf msg -let unpad block_size cs = - let l = String.length cs in - let amount = String.get_uint8 cs (pred l) in +let unpad block_size cs off = + let l = String.length cs - off in + let amount = String.get_uint8 cs (off + pred l) in let len = l - amount in - if len >= 0 && amount <= block_size then Ok (String.sub cs 0 len) + if len >= 0 && amount <= block_size then Ok (String.sub cs off len) else Error (`Msg "bad padding") let out ?add_timestamp prefix_len (ctx : keys) hmac_algorithm compress rng data @@ -1288,20 +1288,17 @@ let incoming_data ?(add_timestamp = false) err (ctx : keys) hmac_algorithm *) let open Mirage_crypto in let module H = (val Digestif.module_of_hash' hmac_algorithm) in - let hmac, data = - ( String.sub data 0 H.digest_size, - String.sub data H.digest_size (String.length data - H.digest_size) - ) - in + let hmac, off = (String.sub data 0 H.digest_size, H.digest_size) in let computed_hmac = - H.(to_raw_string (hmac_string ~key:their_hmac data)) + H.(to_raw_string (hmac_string ~off ~key:their_hmac data)) in let* () = guard (String.equal hmac computed_hmac) (err computed_hmac) in - let iv, data = - ( String.sub data 0 Cipher_block.AES.CBC.block_size, - String.sub data Cipher_block.AES.CBC.block_size - (String.length data - Cipher_block.AES.CBC.block_size) ) + let iv, off = + ( String.sub data off Cipher_block.AES.CBC.block_size, + off + Cipher_block.AES.CBC.block_size ) in + (* TODO: decrypt could take an offset and length to avoid copying *) + let data = String.sub data off (String.length data - off) in let dec = Cipher_block.AES.CBC.decrypt ~key:their_key ~iv data in (* dec is: uint32 replay packet id followed by (lzo-compressed) data and padding *) let hdr_len = Packet.id_len + if add_timestamp then 4 else 0 in @@ -1314,8 +1311,7 @@ let incoming_data ?(add_timestamp = false) err (ctx : keys) hmac_algorithm Log.debug (fun m -> m "received replay packet id is %lu" (String.get_int32_le dec 0)); (* TODO validate ts if provided (avoid replay) *) - unpad Cipher_block.AES.CBC.block_size - (String.sub dec hdr_len (String.length dec - hdr_len)) + unpad Cipher_block.AES.CBC.block_size dec hdr_len | AES_GCM { their_key; their_implicit_iv; _ } -> let tag_len = Mirage_crypto.Cipher_block.AES.GCM.tag_size in let* () =