diff --git a/.ocamlformat b/.ocamlformat new file mode 100644 index 0000000..71d5f8a --- /dev/null +++ b/.ocamlformat @@ -0,0 +1,8 @@ +version = 0.14.1 +break-infix = fit-or-vertical +parse-docstrings = true +indicate-multiline-delimiters=no +nested-match=align +sequence-style=separator +break-before-in=auto +if-then-else=keyword-first diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 9c902a4..0000000 --- a/.travis.yml +++ /dev/null @@ -1,19 +0,0 @@ -language: c -install: wget https://raw.githubusercontent.com/ocaml/ocaml-travisci-skeleton/master/.travis-docker.sh -script: bash -ex .travis-docker.sh -services: - - docker -sudo: false -env: - global: - - PACKAGE="base64" - - PRE_INSTALL_HOOK="cd /home/opam/opam-repository && git pull origin master && opam update -u -y" - matrix: - - DISTRO=debian-stable OCAML_VERSION=4.05 - - DISTRO=alpine OCAML_VERSION=4.06 - - DISTRO=ubuntu-16.04 OCAML_VERSION=4.07 -# - DISTRO=ubuntu-12.04 OCAML_VERSION=4.01.0 -# - DISTRO=ubuntu-16.04 OCAML_VERSION=4.03.0 -# - DISTRO=centos-6 OCAML_VERSION=4.02.3 -# - DISTRO=centos-7 OCAML_VERSION=4.03.0 -# - DISTRO=fedora-24 OCAML_VERSION=4.02.3 diff --git a/bench/benchmarks.ml b/bench/benchmarks.ml index 05b02e9..f0133b1 100644 --- a/bench/benchmarks.ml +++ b/bench/benchmarks.ml @@ -15,9 +15,9 @@ module Old_version = struct let decode ?alphabet input = let length = String.length input in let input = - if length mod 4 = 0 then input - else input ^ String.make (4 - (length mod 4)) padding - in + if length mod 4 = 0 + then input + else input ^ String.make (4 - (length mod 4)) padding in let length = String.length input in let words = length / 4 in let padding = @@ -25,8 +25,7 @@ module Old_version = struct | 0 -> 0 | _ when input.[length - 2] = padding -> 2 | _ when input.[length - 1] = padding -> 1 - | _ -> 0 - in + | _ -> 0 in let output = Bytes.make ((words * 3) - padding) '\000' in for i = 0 to words - 1 do let a = of_char ?alphabet input.[(4 * i) + 0] @@ -38,10 +37,10 @@ module Old_version = struct and y = (n lsr 8) land 255 and z = n land 255 in Bytes.set output ((3 * i) + 0) (char_of_int x) ; - if i <> words - 1 || padding < 2 then - Bytes.set output ((3 * i) + 1) (char_of_int y) ; - if i <> words - 1 || padding < 1 then - Bytes.set output ((3 * i) + 2) (char_of_int z) + if i <> words - 1 || padding < 2 + then Bytes.set output ((3 * i) + 1) (char_of_int y) ; + if i <> words - 1 || padding < 1 + then Bytes.set output ((3 * i) + 2) (char_of_int z) done ; Bytes.unsafe_to_string output @@ -71,7 +70,8 @@ module Old_version = struct for i = 1 to padding_len do Bytes.set output (Bytes.length output - i) padding done ; - if pad then Bytes.unsafe_to_string output + if pad + then Bytes.unsafe_to_string output else Bytes.sub_string output 0 (Bytes.length output - padding_len) end @@ -101,15 +101,10 @@ let old_encode_and_decode len = let args = [ 0; 10; 50; 100; 500; 1000; 2500; 5000 ] -let test_b64 = - Test.create_indexed ~name:"Base64" - ~args b64_encode_and_decode +let test_b64 = Test.create_indexed ~name:"Base64" ~args b64_encode_and_decode -let test_old = - Test.create_indexed ~name:"Old" - ~args old_encode_and_decode +let test_old = Test.create_indexed ~name:"Old" ~args old_encode_and_decode -let command = - Bench.make_command [ test_b64; test_old ] +let command = Bench.make_command [ test_b64; test_old ] let () = Command.run command diff --git a/bench/dune b/bench/dune index 15e24b7..8c5b1e0 100644 --- a/bench/dune +++ b/bench/dune @@ -1,3 +1,3 @@ (executable (name benchmarks) - (libraries base64 core_bench)) \ No newline at end of file + (libraries base64 core_bench)) diff --git a/config/config.ml b/config/config.ml index 8ea3416..d5b3a76 100644 --- a/config/config.ml +++ b/config/config.ml @@ -1,48 +1,54 @@ module Config = Configurator.V1 -let pre407 = {ocaml|external unsafe_set_uint16 : bytes -> int -> int -> unit = "%caml_string_set16u" [@@noalloc]|ocaml} -let standard = {ocaml|external unsafe_set_uint16 : bytes -> int -> int -> unit = "%caml_bytes_set16u" [@@noalloc]|ocaml} +let pre407 = + {ocaml|external unsafe_set_uint16 : bytes -> int -> int -> unit = "%caml_string_set16u" [@@noalloc]|ocaml} -type t = - { major : int - ; minor : int - ; patch : int option - ; extra : string option } +let standard = + {ocaml|external unsafe_set_uint16 : bytes -> int -> int -> unit = "%caml_bytes_set16u" [@@noalloc]|ocaml} -let v ?patch ?extra major minor = { major; minor; patch; extra; } +type t = { major : int; minor : int; patch : int option; extra : string option } + +let v ?patch ?extra major minor = { major; minor; patch; extra } let parse s = - try Scanf.sscanf s "%d.%d.%d+%s" (fun major minor patch extra -> v ~patch ~extra major minor) - with End_of_file | Scanf.Scan_failure _ -> - ( try Scanf.sscanf s "%d.%d+%s" (fun major minor extra -> v ~extra major minor) + try + Scanf.sscanf s "%d.%d.%d+%s" (fun major minor patch extra -> + v ~patch ~extra major minor) + with End_of_file | Scanf.Scan_failure _ -> ( + try + Scanf.sscanf s "%d.%d+%s" (fun major minor extra -> v ~extra major minor) + with End_of_file | Scanf.Scan_failure _ -> ( + try + Scanf.sscanf s "%d.%d.%d" (fun major minor patch -> + v ~patch major minor) with End_of_file | Scanf.Scan_failure _ -> - ( try Scanf.sscanf s "%d.%d.%d" (fun major minor patch -> v ~patch major minor) - with End_of_file | Scanf.Scan_failure _ -> - Scanf.sscanf s "%d.%d" (fun major minor -> v major minor) ) ) + Scanf.sscanf s "%d.%d" (fun major minor -> v major minor))) -let ( >|= ) x f = match x with - | Some x -> Some (f x ) - | None -> None +let ( >|= ) x f = match x with Some x -> Some (f x) | None -> None let ocaml_cp ~src ~dst = let ic = open_in src in let oc = open_out dst in let bf = Bytes.create 0x1000 in - let rec go () = match input ic bf 0 (Bytes.length bf) with + let rec go () = + match input ic bf 0 (Bytes.length bf) with | 0 -> () - | len -> output oc bf 0 len ; go () + | len -> + output oc bf 0 len ; + go () | exception End_of_file -> () in - go () ; close_in ic ; close_out oc -;; + go () ; + close_in ic ; + close_out oc let () = Config.main ~name:"config-base64" @@ fun t -> match Config.ocaml_config_var t "version" >|= parse with | Some version -> - let dst = "unsafe.ml" in + let dst = "unsafe.ml" in - if (version.major, version.minor) >= (4, 7) - then ocaml_cp ~src:"unsafe_stable.ml" ~dst - else ocaml_cp ~src:"unsafe_pre407.ml" ~dst + if (version.major, version.minor) >= (4, 7) + then ocaml_cp ~src:"unsafe_stable.ml" ~dst + else ocaml_cp ~src:"unsafe_pre407.ml" ~dst | None -> Config.die "OCaml version is not available" | exception exn -> Config.die "Got an exception: %s" (Printexc.to_string exn) diff --git a/dune-project b/dune-project index c305290..2af7ef7 100644 --- a/dune-project +++ b/dune-project @@ -1,3 +1,2 @@ -(lang dune 1.0) +(lang dune 2.0) (name base64) -(version dev) diff --git a/fuzz/dune b/fuzz/dune index 92b4af1..9dd8385 100644 --- a/fuzz/dune +++ b/fuzz/dune @@ -6,4 +6,4 @@ (executable (name fuzz_rfc4648) (modules fuzz_rfc4648) - (libraries astring crowbar fmt base64)) \ No newline at end of file + (libraries astring crowbar fmt base64)) diff --git a/fuzz/fuzz_rfc2045.ml b/fuzz/fuzz_rfc2045.ml index 81f2c64..7b76488 100644 --- a/fuzz/fuzz_rfc2045.ml +++ b/fuzz/fuzz_rfc2045.ml @@ -1,6 +1,7 @@ open Crowbar exception Encode_error of string + exception Decode_error of string (** Pretty printers *) @@ -9,13 +10,14 @@ let register_printer () = Printexc.register_printer (function | Encode_error err -> Some (Fmt.strf "(Encoding error: %s)" err) | Decode_error err -> Some (Fmt.strf "(Decoding error: %s)" err) - | _ -> None ) + | _ -> None) let pp_chr = let escaped = function ' ' .. '~' as c -> String.make 1 c | _ -> "." in Fmt.using escaped Fmt.string -let pp_scalar : type buffer. +let pp_scalar : + type buffer. get:(buffer -> int -> char) -> length:(buffer -> int) -> buffer Fmt.t = fun ~get ~length ppf b -> let l = length b in @@ -23,8 +25,8 @@ let pp_scalar : type buffer. Fmt.pf ppf "%08x: " (i * 16) ; let j = ref 0 in while !j < 16 do - if (i * 16) + !j < l then - Fmt.pf ppf "%02x" (Char.code @@ get b ((i * 16) + !j)) + if (i * 16) + !j < l + then Fmt.pf ppf "%02x" (Char.code @@ get b ((i * 16) + !j)) else Fmt.pf ppf " " ; if !j mod 2 <> 0 then Fmt.pf ppf " " ; incr j @@ -32,7 +34,8 @@ let pp_scalar : type buffer. Fmt.pf ppf " " ; j := 0 ; while !j < 16 do - if (i * 16) + !j < l then Fmt.pf ppf "%a" pp_chr (get b ((i * 16) + !j)) + if (i * 16) + !j < l + then Fmt.pf ppf "%a" pp_chr (get b ((i * 16) + !j)) else Fmt.pf ppf " " ; incr j done ; @@ -46,10 +49,10 @@ let pp = pp_scalar ~get:String.get ~length:String.length let check_encode str = let subs = Astring.String.cuts ~sep:"\r\n" str in let check str = - if String.length str > 78 then - raise (Encode_error "too long string returned") - in - List.iter check subs ; str + if String.length str > 78 + then raise (Encode_error "too long string returned") in + List.iter check subs ; + str let encode input = let buf = Buffer.create 80 in @@ -57,7 +60,7 @@ let encode input = String.iter (fun c -> let ret = Base64_rfc2045.encode encoder (`Char c) in - match ret with `Ok -> () | _ -> assert false ) + match ret with `Ok -> () | _ -> assert false) (* XXX(dinosaure): [`Partial] can never occur. *) input ; let encode = Base64_rfc2045.encode encoder `End in @@ -68,15 +71,14 @@ let encode input = let decode input = let decoder = Base64_rfc2045.decoder (`String input) in let rec go acc = - if Base64_rfc2045.decoder_dangerous decoder then - raise (Decode_error "Dangerous input") ; + if Base64_rfc2045.decoder_dangerous decoder + then raise (Decode_error "Dangerous input") ; match Base64_rfc2045.decode decoder with | `End -> List.rev acc | `Flush output -> go (output :: acc) | `Malformed _ -> raise (Decode_error "Malformed") | `Wrong_padding -> raise (Decode_error "Wrong padding") - | _ -> (* XXX(dinosaure): [`Await] can never occur. *) assert false - in + | _ -> (* XXX(dinosaure): [`Await] can never occur. *) assert false in String.concat "" (go []) (** String generators *) @@ -84,7 +86,7 @@ let decode input = let bytes_fixed_range : string gen = dynamic_bind (range 78) bytes_fixed let char_from_alpha alpha : string gen = - map [range (String.length alpha)] (fun i -> alpha.[i] |> String.make 1) + map [ range (String.length alpha) ] (fun i -> alpha.[i] |> String.make 1) let string_from_alpha n = let acc = const "" in @@ -93,9 +95,8 @@ let string_from_alpha n = | 0 -> acc | n -> add_char_from_alpha alpha - (concat_gen_list (const "") [acc; char_from_alpha alpha]) - (n - 1) - in + (concat_gen_list (const "") [ acc; char_from_alpha alpha ]) + (n - 1) in add_char_from_alpha alpha acc n let random_string_from_alpha n = dynamic_bind (range n) string_from_alpha @@ -106,23 +107,20 @@ let bytes_fixed_range_from_alpha : string gen = let set_canonic str = let l = String.length str in let to_drop = l * 6 mod 8 in - if - to_drop = 6 - (* XXX(clecat): Case when we need to drop 6 bits which means a whole letter *) + if to_drop = 6 + (* XXX(clecat): Case when we need to drop 6 bits which means a whole letter *) then String.sub str 0 (l - 1) - else if - to_drop <> 0 - (* XXX(clecat): Case when we need to drop 2 or 4 bits: we apply a mask droping the bits *) + else if to_drop <> 0 + (* XXX(clecat): Case when we need to drop 2 or 4 bits: we apply a mask droping the bits *) then ( let buf = Bytes.of_string str in let value = - String.index Base64_rfc2045.default_alphabet (Bytes.get buf (l - 1)) - in + String.index Base64_rfc2045.default_alphabet (Bytes.get buf (l - 1)) in let canonic = Base64_rfc2045.default_alphabet.[value land lnot ((1 lsl to_drop) - 1)] in Bytes.set buf (l - 1) canonic ; - Bytes.unsafe_to_string buf ) + Bytes.unsafe_to_string buf) else str let add_padding str = @@ -140,19 +138,18 @@ let e2d inputs = let d2e inputs end_input = let end_input = add_padding end_input in - let inputs = inputs @ [end_input] in + let inputs = inputs @ [ end_input ] in let input = List.fold_left (fun acc s -> if String.length s <> 0 then acc ^ "\r\n" ^ s else acc) - (List.hd inputs) (List.tl inputs) - in + (List.hd inputs) (List.tl inputs) in let decode = decode input in let encode = encode decode in check_eq ~pp ~cmp:String.compare ~eq:String.equal input encode let () = register_printer () ; - add_test ~name:"rfc2045: encode -> decode" [list bytes_fixed_range] e2d ; + add_test ~name:"rfc2045: encode -> decode" [ list bytes_fixed_range ] e2d ; add_test ~name:"rfc2045: decode -> encode" - [list (string_from_alpha 76); random_string_from_alpha 76] + [ list (string_from_alpha 76); random_string_from_alpha 76 ] d2e diff --git a/fuzz/fuzz_rfc4648.ml b/fuzz/fuzz_rfc4648.ml index 67d3b7d..ecab44d 100644 --- a/fuzz/fuzz_rfc4648.ml +++ b/fuzz/fuzz_rfc4648.ml @@ -4,7 +4,8 @@ let pp_chr = let escaped = function ' ' .. '~' as c -> String.make 1 c | _ -> "." in Fmt.using escaped Fmt.string -let pp_scalar : type buffer. +let pp_scalar : + type buffer. get:(buffer -> int -> char) -> length:(buffer -> int) -> buffer Fmt.t = fun ~get ~length ppf b -> let l = length b in @@ -12,8 +13,8 @@ let pp_scalar : type buffer. Fmt.pf ppf "%08x: " (i * 16) ; let j = ref 0 in while !j < 16 do - if (i * 16) + !j < l then - Fmt.pf ppf "%02x" (Char.code @@ get b ((i * 16) + !j)) + if (i * 16) + !j < l + then Fmt.pf ppf "%02x" (Char.code @@ get b ((i * 16) + !j)) else Fmt.pf ppf " " ; if !j mod 2 <> 0 then Fmt.pf ppf " " ; incr j @@ -21,7 +22,8 @@ let pp_scalar : type buffer. Fmt.pf ppf " " ; j := 0 ; while !j < 16 do - if (i * 16) + !j < l then Fmt.pf ppf "%a" pp_chr (get b ((i * 16) + !j)) + if (i * 16) + !j < l + then Fmt.pf ppf "%a" pp_chr (get b ((i * 16) + !j)) else Fmt.pf ppf " " ; incr j done ; @@ -30,94 +32,96 @@ let pp_scalar : type buffer. let pp = pp_scalar ~get:String.get ~length:String.length -let (<.>) f g x = f (g x) +let ( <.> ) f g x = f (g x) let char_from_alphabet alphabet : string gen = - map [ range 64 ] (String.make 1 <.> Char.chr <.> Array.unsafe_get (Base64.alphabet alphabet)) + map [ range 64 ] + (String.make 1 <.> Char.chr <.> Array.unsafe_get (Base64.alphabet alphabet)) let random_string_from_alphabet alphabet len : string gen = let rec add_char_from_alphabet acc = function - | 0 -> acc - | n -> - add_char_from_alphabet - (concat_gen_list (const "") [ acc ; char_from_alphabet alphabet ]) - (n - 1) in + | 0 -> acc + | n -> + add_char_from_alphabet + (concat_gen_list (const "") [ acc; char_from_alphabet alphabet ]) + (n - 1) in add_char_from_alphabet (const "") len let random_string_from_alphabet ~max alphabet = - dynamic_bind (range max) - @@ fun real_len -> - dynamic_bind (random_string_from_alphabet alphabet real_len) - @@ fun input -> - if real_len <= 1 then const (input, 0, real_len) - else dynamic_bind (range (real_len / 2)) - @@ fun off -> map [ range (real_len - off) ] (fun len -> (input, off, len)) + dynamic_bind (range max) @@ fun real_len -> + dynamic_bind (random_string_from_alphabet alphabet real_len) @@ fun input -> + if real_len <= 1 + then const (input, 0, real_len) + else + dynamic_bind (range (real_len / 2)) @@ fun off -> + map [ range (real_len - off) ] (fun len -> (input, off, len)) let encode_and_decode (input, off, len) = match Base64.encode ~pad:true ~off ~len input with | Error (`Msg err) -> fail err | Ok result -> - match Base64.decode ~pad:true result with - | Error (`Msg err) -> fail err - | Ok result -> - check_eq ~pp ~cmp:String.compare ~eq:String.equal result (String.sub input off len) + match Base64.decode ~pad:true result with + | Error (`Msg err) -> fail err + | Ok result -> + check_eq ~pp ~cmp:String.compare ~eq:String.equal result + (String.sub input off len) let decode_and_encode (input, off, len) = match Base64.decode ~pad:true ~off ~len input with - | Error (`Msg err) -> - fail err + | Error (`Msg err) -> fail err | Ok result -> - match Base64.encode ~pad:true result with - | Error (`Msg err) -> fail err - | Ok result -> - check_eq ~pp:Fmt.string ~cmp:String.compare ~eq:String.equal result (String.sub input off len) + match Base64.encode ~pad:true result with + | Error (`Msg err) -> fail err + | Ok result -> + check_eq ~pp:Fmt.string ~cmp:String.compare ~eq:String.equal result + (String.sub input off len) -let (//) x y = +let ( // ) x y = if y < 1 then raise Division_by_zero ; if x > 0 then 1 + ((x - 1) / y) else 0 -[@@inline] + [@@inline] let canonic alphabet = let dmap = Array.make 256 (-1) in - Array.iteri (fun i x -> Array.set dmap x i) (Base64.alphabet alphabet) ; + Array.iteri (fun i x -> dmap.(x) <- i) (Base64.alphabet alphabet) ; fun (input, off, len) -> let real_len = String.length input in let input_len = len in - let normalized_len = (input_len // 4) * 4 in - if normalized_len = input_len then (input, off, input_len) - else if normalized_len - input_len = 3 then (input, off, input_len - 1) - else begin + let normalized_len = input_len // 4 * 4 in + if normalized_len = input_len + then (input, off, input_len) + else if normalized_len - input_len = 3 + then (input, off, input_len - 1) + else let remainder_len = normalized_len - input_len in - let last = String.get input (off + input_len - 1) in + let last = input.[off + input_len - 1] in let output = Bytes.make (max real_len (off + normalized_len)) '=' in - Bytes.blit_string input 0 output 0 (off + input_len); + Bytes.blit_string input 0 output 0 (off + input_len) ; if off + normalized_len < real_len - then Bytes.blit_string input (off + normalized_len) output (off + normalized_len) (real_len - (off + normalized_len)) ; - - let mask = match remainder_len with - | 1 -> 0x3c - | 2 -> 0x30 - | _ -> assert false in - let decoded = Array.get dmap (Char.code last) in - let canonic = (decoded land mask) in - let encoded = Array.get (Base64.alphabet alphabet) canonic in + then + Bytes.blit_string input (off + normalized_len) output + (off + normalized_len) + (real_len - (off + normalized_len)) ; + + let mask = + match remainder_len with 1 -> 0x3c | 2 -> 0x30 | _ -> assert false in + let decoded = dmap.(Char.code last) in + let canonic = decoded land mask in + let encoded = (Base64.alphabet alphabet).(canonic) in Bytes.set output (off + input_len - 1) (Char.chr encoded) ; (Bytes.unsafe_to_string output, off, normalized_len) - end let isomorphism0 (input, off, len) = (* x0 = decode(input) && x1 = decode(encode(x0)) && x0 = x1 *) match Base64.decode ~pad:false ~off ~len input with - | Error (`Msg err) -> - fail err - | Ok result0 -> + | Error (`Msg err) -> fail err + | Ok result0 -> ( let result1 = Base64.encode_exn result0 in match Base64.decode ~pad:true result1 with - | Error (`Msg err) -> - fail err + | Error (`Msg err) -> fail err | Ok result2 -> - check_eq ~pp ~cmp:String.compare ~eq:String.equal result0 result2 + check_eq ~pp ~cmp:String.compare ~eq:String.equal result0 result2) let isomorphism1 (input, off, len) = let result0 = Base64.encode_exn ~off ~len input in @@ -125,31 +129,42 @@ let isomorphism1 (input, off, len) = | Error (`Msg err) -> fail err | Ok result1 -> let result2 = Base64.encode_exn result1 in - check_eq ~pp:Fmt.string ~cmp:String.compare ~eq:String.equal result0 result2 + check_eq ~pp:Fmt.string ~cmp:String.compare ~eq:String.equal result0 + result2 let bytes_and_range : (string * int * int) gen = - dynamic_bind bytes - @@ fun t -> + dynamic_bind bytes @@ fun t -> let real_length = String.length t in if real_length <= 1 then const (t, 0, real_length) - else dynamic_bind (range (real_length / 2)) - @@ fun off -> + else + dynamic_bind (range (real_length / 2)) @@ fun off -> map [ range (real_length - off) ] (fun len -> (t, off, len)) let range_of_max max : (int * int) gen = - dynamic_bind (range (max / 2)) - @@ fun off -> map [ range (max - off) ] (fun len -> (off, len)) + dynamic_bind (range (max / 2)) @@ fun off -> + map [ range (max - off) ] (fun len -> (off, len)) let failf fmt = Fmt.kstrf fail fmt let no_exception pad off len input = - try let _ = Base64.decode ?pad ?off ?len ~alphabet:Base64.default_alphabet input in () + try + let _ = + Base64.decode ?pad ?off ?len ~alphabet:Base64.default_alphabet input in + () with exn -> failf "decode fails with: %s." (Printexc.to_string exn) let () = - add_test ~name:"rfc4648: encode -> decode" [ bytes_and_range ] encode_and_decode ; - add_test ~name:"rfc4648: decode -> encode" [ random_string_from_alphabet ~max:1000 Base64.default_alphabet ] (decode_and_encode <.> canonic Base64.default_alphabet) ; - add_test ~name:"rfc4648: x = decode(encode(x))" [ random_string_from_alphabet ~max:1000 Base64.default_alphabet ] isomorphism0 ; - add_test ~name:"rfc4648: x = encode(decode(x))" [ bytes_and_range ] isomorphism1 ; - add_test ~name:"rfc4648: no exception leak" [ option bool; option int; option int; bytes ] no_exception + add_test ~name:"rfc4648: encode -> decode" [ bytes_and_range ] + encode_and_decode ; + add_test ~name:"rfc4648: decode -> encode" + [ random_string_from_alphabet ~max:1000 Base64.default_alphabet ] + (decode_and_encode <.> canonic Base64.default_alphabet) ; + add_test ~name:"rfc4648: x = decode(encode(x))" + [ random_string_from_alphabet ~max:1000 Base64.default_alphabet ] + isomorphism0 ; + add_test ~name:"rfc4648: x = encode(decode(x))" [ bytes_and_range ] + isomorphism1 ; + add_test ~name:"rfc4648: no exception leak" + [ option bool; option int; option int; bytes ] + no_exception diff --git a/src/base64.ml b/src/base64.ml index 10297f4..7cb22e4 100644 --- a/src/base64.ml +++ b/src/base64.ml @@ -19,42 +19,53 @@ * *) -type alphabet = - { emap : int array - ; dmap : int array } +type alphabet = { emap : int array; dmap : int array } type sub = string * int * int -let (//) x y = +let ( // ) x y = if y < 1 then raise Division_by_zero ; if x > 0 then 1 + ((x - 1) / y) else 0 -[@@inline] + [@@inline] let unsafe_get_uint8 t off = Char.code (String.unsafe_get t off) + let unsafe_set_uint8 t off v = Bytes.unsafe_set t off (Char.chr v) + let unsafe_set_uint16 = Unsafe.unsafe_set_uint16 -external unsafe_get_uint16 : string -> int -> int = "%caml_string_get16u" [@@noalloc] +external unsafe_get_uint16 : string -> int -> int = "%caml_string_get16u" + [@@noalloc] + external swap16 : int -> int = "%bswap16" [@@noalloc] -let none = (-1) +let none = -1 (* We mostly want to have an optional array for [dmap] (e.g. [int option array]). So we consider the [none] value as [-1]. *) let make_alphabet alphabet = - if String.length alphabet <> 64 then invalid_arg "Length of alphabet must be 64" ; - if String.contains alphabet '=' then invalid_arg "Alphabet can not contain padding character" ; - let emap = Array.init (String.length alphabet) (fun i -> Char.code (String.get alphabet i)) in + if String.length alphabet <> 64 + then invalid_arg "Length of alphabet must be 64" ; + if String.contains alphabet '=' + then invalid_arg "Alphabet can not contain padding character" ; + let emap = + Array.init (String.length alphabet) (fun i -> Char.code alphabet.[i]) in let dmap = Array.make 256 none in - String.iteri (fun idx chr -> Array.set dmap (Char.code chr) idx) alphabet ; - { emap; dmap; } + String.iteri (fun idx chr -> dmap.(Char.code chr) <- idx) alphabet ; + { emap; dmap } let length_alphabet { emap; _ } = Array.length emap + let alphabet { emap; _ } = emap -let default_alphabet = make_alphabet "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" -let uri_safe_alphabet = make_alphabet "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" +let default_alphabet = + make_alphabet + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + +let uri_safe_alphabet = + make_alphabet + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" let unsafe_set_be_uint16 = if Sys.big_endian @@ -65,6 +76,7 @@ let unsafe_set_be_uint16 = can raise and avoid appearance of unknown exceptions like an ex-nihilo magic rabbit (or magic money?). *) exception Out_of_bounds + exception Too_much_input let get_uint8 t off = @@ -76,52 +88,61 @@ let padding = int_of_char '=' let error_msgf fmt = Format.ksprintf (fun err -> Error (`Msg err)) fmt let encode_sub pad { emap; _ } ?(off = 0) ?len input = - let len = match len with - | Some len -> len - | None -> String.length input - off in + let len = + match len with Some len -> len | None -> String.length input - off in if len < 0 || off < 0 || off > String.length input - len then error_msgf "Invalid bounds" else - let n = len in - let n' = n // 3 * 4 in - let res = Bytes.create n' in - - let emap i = Array.unsafe_get emap i in - - let emit b1 b2 b3 i = - unsafe_set_be_uint16 res i - ((emap (b1 lsr 2 land 0x3f) lsl 8) - lor (emap ((b1 lsl 4) lor (b2 lsr 4) land 0x3f))) ; - unsafe_set_be_uint16 res (i + 2) - ((emap ((b2 lsl 2) lor (b3 lsr 6) land 0x3f) lsl 8) - lor (emap (b3 land 0x3f))) in - - let rec enc j i = - if i = n then () - else if i = n - 1 - then emit (unsafe_get_uint8 input (off + i)) 0 0 j - else if i = n - 2 - then emit (unsafe_get_uint8 input (off + i)) (unsafe_get_uint8 input (off + i + 1)) 0 j - else - (emit - (unsafe_get_uint8 input (off + i)) - (unsafe_get_uint8 input (off + i + 1)) - (unsafe_get_uint8 input (off + i + 2)) - j ; - enc (j + 4) (i + 3)) in - - let rec unsafe_fix = function - | 0 -> () - | i -> unsafe_set_uint8 res (n' - i) padding ; unsafe_fix (i - 1) in - - enc 0 0 ; - - let pad_to_write = ((3 - n mod 3) mod 3) in - - if pad - then begin unsafe_fix pad_to_write ; Ok (Bytes.unsafe_to_string res, 0, n') end - else Ok (Bytes.unsafe_to_string res, 0, (n' - pad_to_write)) + let n = len in + let n' = n // 3 * 4 in + let res = Bytes.create n' in + + let emap i = Array.unsafe_get emap i in + + let emit b1 b2 b3 i = + unsafe_set_be_uint16 res i + ((emap ((b1 lsr 2) land 0x3f) lsl 8) + lor emap ((b1 lsl 4) lor (b2 lsr 4) land 0x3f)) ; + unsafe_set_be_uint16 res (i + 2) + ((emap ((b2 lsl 2) lor (b3 lsr 6) land 0x3f) lsl 8) + lor emap (b3 land 0x3f)) in + + let rec enc j i = + if i = n + then () + else if i = n - 1 + then emit (unsafe_get_uint8 input (off + i)) 0 0 j + else if i = n - 2 + then + emit + (unsafe_get_uint8 input (off + i)) + (unsafe_get_uint8 input (off + i + 1)) + 0 j + else ( + emit + (unsafe_get_uint8 input (off + i)) + (unsafe_get_uint8 input (off + i + 1)) + (unsafe_get_uint8 input (off + i + 2)) + j ; + enc (j + 4) (i + 3)) in + + let rec unsafe_fix = function + | 0 -> () + | i -> + unsafe_set_uint8 res (n' - i) padding ; + unsafe_fix (i - 1) in + + enc 0 0 ; + + let pad_to_write = (3 - (n mod 3)) mod 3 in + + if pad + then ( + unsafe_fix pad_to_write ; + Ok (Bytes.unsafe_to_string res, 0, n')) + else Ok (Bytes.unsafe_to_string res, 0, n' - pad_to_write) + (* [pad = false], we don't want to write them. *) let encode ?(pad = true) ?(alphabet = default_alphabet) ?off ?len input = @@ -143,129 +164,135 @@ let encode_exn ?pad ?alphabet ?off ?len input = | Error (`Msg err) -> invalid_arg err let decode_sub ?(pad = true) { dmap; _ } ?(off = 0) ?len input = - let len = match len with - | Some len -> len - | None -> String.length input - off in + let len = + match len with Some len -> len | None -> String.length input - off in if len < 0 || off < 0 || off > String.length input - len then error_msgf "Invalid bounds" else - - let n = (len // 4) * 4 in - let n' = (n // 4) * 3 in - let res = Bytes.create n' in - let invalid_pad_overflow = pad in - - let get_uint8_or_padding = - if pad then (fun t i -> if i >= len then raise Out_of_bounds ; get_uint8 t (off + i) ) - else (fun t i -> try if i < len then get_uint8 t (off + i) else padding with Out_of_bounds -> padding ) in - - let set_be_uint16 t off v = - (* can not write 2 bytes. *) - if off < 0 || off + 1 > Bytes.length t then () - (* can not write 1 byte but can write 1 byte *) - else if off < 0 || off + 2 > Bytes.length t then unsafe_set_uint8 t off (v lsr 8) - (* can write 2 bytes. *) - else unsafe_set_be_uint16 t off v in - - let set_uint8 t off v = - if off < 0 || off >= Bytes.length t then () - else unsafe_set_uint8 t off v in - - let emit a b c d j = - let x = (a lsl 18) lor (b lsl 12) lor (c lsl 6) lor d in - set_be_uint16 res j (x lsr 8) ; - set_uint8 res (j + 2) (x land 0xff) in - - let dmap i = - let x = Array.unsafe_get dmap i in - if x = none then raise Not_found ; x in - - let only_padding pad idx = - - (* because we round length of [res] to the upper bound of how many - characters we should have from [input], we got at this stage only padding - characters and we need to delete them, so for each [====], we delete 3 - bytes. *) - - let pad = ref (pad + 3) in - let idx = ref idx in - - while !idx + 4 < len do - (* use [unsafe_get_uint16] instead [unsafe_get_uint32] to avoid allocation - of [int32]. Of course, [3d3d3d3d] is [====]. *) - if unsafe_get_uint16 input (off + !idx) <> 0x3d3d - || unsafe_get_uint16 input (off + !idx + 2) <> 0x3d3d - then raise Not_found ; - (* We got something bad, should be a valid character according to - [alphabet] but outside the scope. *) - - idx := !idx + 4 ; - pad := !pad + 3 ; - done ; - while !idx < len do - if unsafe_get_uint8 input (off + !idx) <> padding - then raise Not_found ; - - incr idx ; - done ; !pad in - - let rec dec j i = - if i = n then 0 - else begin - let (d, pad) = - let x = get_uint8_or_padding input (i + 3) in - try (dmap x, 0) with Not_found when x = padding -> (0, 1) in - (* [Not_found] iff [x ∉ alphabet and x <> '='] can leak. *) - let (c, pad) = - let x = get_uint8_or_padding input (i + 2) in - try (dmap x, pad) with Not_found when x = padding && pad = 1 -> (0, 2) in - (* [Not_found] iff [x ∉ alphabet and x <> '='] can leak. *) - let (b, pad) = - let x = get_uint8_or_padding input (i + 1) in - try (dmap x, pad) with Not_found when x = padding && pad = 2 -> (0, 3) in - (* [Not_found] iff [x ∉ alphabet and x <> '='] can leak. *) - let (a, pad) = - let x = get_uint8_or_padding input i in - try (dmap x, pad) with Not_found when x = padding && pad = 3 -> (0, 4) in - (* [Not_found] iff [x ∉ alphabet and x <> '='] can leak. *) - - emit a b c d j ; - - if i + 4 = n - (* end of input in anyway *) - then match pad with - | 0 -> - 0 - | 4 -> - (* assert (invalid_pad_overflow = false) ; *) - 3 - (* [get_uint8] lies and if we get [4], that mean we got one or more (at - most 4) padding character. In this situation, because we round length - of [res] (see [n // 4]), we need to delete 3 bytes. *) - | pad -> - pad - else match pad with - | 0 -> dec (j + 3) (i + 4) - | 4 -> - (* assert (invalid_pad_overflow = false) ; *) - only_padding 3 (i + 4) - (* Same situation than above but we should get only more padding - characters then. *) - | pad -> - if invalid_pad_overflow = true then raise Too_much_input ; - only_padding pad (i + 4) end in - - match dec 0 0 with - | 0 -> Ok (Bytes.unsafe_to_string res, 0, n') - | pad -> Ok (Bytes.unsafe_to_string res, 0, (n' - pad)) - | exception Out_of_bounds -> error_msgf "Wrong padding" - (* appear only when [pad = true] and when length of input is not a multiple of 4. *) - | exception Not_found -> - (* appear when one character of [input] ∉ [alphabet] and this character <> '=' *) - error_msgf "Malformed input" - | exception Too_much_input -> - error_msgf "Too much input" + let n = len // 4 * 4 in + let n' = n // 4 * 3 in + let res = Bytes.create n' in + let invalid_pad_overflow = pad in + + let get_uint8_or_padding = + if pad + then (fun t i -> + if i >= len then raise Out_of_bounds ; + get_uint8 t (off + i)) + else + fun t i -> + try if i < len then get_uint8 t (off + i) else padding + with Out_of_bounds -> padding in + + let set_be_uint16 t off v = + (* can not write 2 bytes. *) + if off < 0 || off + 1 > Bytes.length t + then () (* can not write 1 byte but can write 1 byte *) + else if off < 0 || off + 2 > Bytes.length t + then unsafe_set_uint8 t off (v lsr 8) (* can write 2 bytes. *) + else unsafe_set_be_uint16 t off v in + + let set_uint8 t off v = + if off < 0 || off >= Bytes.length t then () else unsafe_set_uint8 t off v + in + + let emit a b c d j = + let x = (a lsl 18) lor (b lsl 12) lor (c lsl 6) lor d in + set_be_uint16 res j (x lsr 8) ; + set_uint8 res (j + 2) (x land 0xff) in + + let dmap i = + let x = Array.unsafe_get dmap i in + if x = none then raise Not_found ; + x in + + let only_padding pad idx = + (* because we round length of [res] to the upper bound of how many + characters we should have from [input], we got at this stage only padding + characters and we need to delete them, so for each [====], we delete 3 + bytes. *) + let pad = ref (pad + 3) in + let idx = ref idx in + + while !idx + 4 < len do + (* use [unsafe_get_uint16] instead [unsafe_get_uint32] to avoid allocation + of [int32]. Of course, [3d3d3d3d] is [====]. *) + if unsafe_get_uint16 input (off + !idx) <> 0x3d3d + || unsafe_get_uint16 input (off + !idx + 2) <> 0x3d3d + then raise Not_found ; + + (* We got something bad, should be a valid character according to + [alphabet] but outside the scope. *) + idx := !idx + 4 ; + pad := !pad + 3 + done ; + while !idx < len do + if unsafe_get_uint8 input (off + !idx) <> padding then raise Not_found ; + + incr idx + done ; + !pad in + + let rec dec j i = + if i = n + then 0 + else + let d, pad = + let x = get_uint8_or_padding input (i + 3) in + try (dmap x, 0) with Not_found when x = padding -> (0, 1) in + (* [Not_found] iff [x ∉ alphabet and x <> '='] can leak. *) + let c, pad = + let x = get_uint8_or_padding input (i + 2) in + try (dmap x, pad) + with Not_found when x = padding && pad = 1 -> (0, 2) in + (* [Not_found] iff [x ∉ alphabet and x <> '='] can leak. *) + let b, pad = + let x = get_uint8_or_padding input (i + 1) in + try (dmap x, pad) + with Not_found when x = padding && pad = 2 -> (0, 3) in + (* [Not_found] iff [x ∉ alphabet and x <> '='] can leak. *) + let a, pad = + let x = get_uint8_or_padding input i in + try (dmap x, pad) + with Not_found when x = padding && pad = 3 -> (0, 4) in + + (* [Not_found] iff [x ∉ alphabet and x <> '='] can leak. *) + emit a b c d j ; + + if i + 4 = n (* end of input in anyway *) + then + match pad with + | 0 -> 0 + | 4 -> + (* assert (invalid_pad_overflow = false) ; *) + 3 + (* [get_uint8] lies and if we get [4], that mean we got one or more (at + most 4) padding character. In this situation, because we round length + of [res] (see [n // 4]), we need to delete 3 bytes. *) + | pad -> pad + else + match pad with + | 0 -> dec (j + 3) (i + 4) + | 4 -> + (* assert (invalid_pad_overflow = false) ; *) + only_padding 3 (i + 4) + (* Same situation than above but we should get only more padding + characters then. *) + | pad -> + if invalid_pad_overflow = true then raise Too_much_input ; + only_padding pad (i + 4) in + + match dec 0 0 with + | 0 -> Ok (Bytes.unsafe_to_string res, 0, n') + | pad -> Ok (Bytes.unsafe_to_string res, 0, n' - pad) + | exception Out_of_bounds -> + error_msgf "Wrong padding" + (* appear only when [pad = true] and when length of input is not a multiple of 4. *) + | exception Not_found -> + (* appear when one character of [input] ∉ [alphabet] and this character <> '=' *) + error_msgf "Malformed input" + | exception Too_much_input -> error_msgf "Too much input" let decode ?pad ?(alphabet = default_alphabet) ?off ?len input = match decode_sub ?pad alphabet ?off ?len input with diff --git a/src/base64.mli b/src/base64.mli index 6857a85..a82b55e 100644 --- a/src/base64.mli +++ b/src/base64.mli @@ -48,13 +48,14 @@ val length_alphabet : alphabet -> int val alphabet : alphabet -> int array (** Returns the alphabet. *) -val decode_exn : ?pad:bool -> ?alphabet:alphabet -> ?off:int -> ?len:int -> string -> string -(** [decode_exn ?off ?len s] decodes [len] bytes (defaults to [String.length s - - off]) of the string [s] starting from [off] (defaults to [0]) that is encoded - in Base64 format. Will leave trailing NULLs on the string, padding it out to - a multiple of 3 characters. [alphabet] defaults to {!default_alphabet}. [pad - = true] specifies to check if [s] is padded or not, otherwise, it raises an - exception. +val decode_exn : + ?pad:bool -> ?alphabet:alphabet -> ?off:int -> ?len:int -> string -> string +(** [decode_exn ?off ?len s] decodes [len] bytes (defaults to + [String.length s - off]) of the string [s] starting from [off] (defaults to + [0]) that is encoded in Base64 format. Will leave trailing NULLs on the + string, padding it out to a multiple of 3 characters. [alphabet] defaults to + {!default_alphabet}. [pad = true] specifies to check if [s] is padded or + not, otherwise, it raises an exception. Decoder can fail when character of [s] is not a part of [alphabet] or is not [padding] character. If input is not padded correctly, decoder does the @@ -62,16 +63,35 @@ val decode_exn : ?pad:bool -> ?alphabet:alphabet -> ?off:int -> ?len:int -> stri @raise if Invalid_argument [s] is not a valid Base64 string. *) -val decode_sub : ?pad:bool -> ?alphabet:alphabet -> ?off:int -> ?len:int -> string -> (sub, [> `Msg of string ]) result +val decode_sub : + ?pad:bool -> + ?alphabet:alphabet -> + ?off:int -> + ?len:int -> + string -> + (sub, [> `Msg of string ]) result (** Same as {!decode_exn} but it returns a result type instead to raise an exception. Then, it returns a {!sub} string. Decoded input [(str, off, len)] will starting to [off] and will have [len] bytes - by this way, we ensure to allocate only one time result. *) -val decode : ?pad:bool -> ?alphabet:alphabet -> ?off:int -> ?len:int -> string -> (string, [> `Msg of string ]) result -(** Same as {!decode_exn}, but returns an explicit error message {!result} if it fails. *) - -val encode : ?pad:bool -> ?alphabet:alphabet -> ?off:int -> ?len:int -> string -> (string, [> `Msg of string]) result +val decode : + ?pad:bool -> + ?alphabet:alphabet -> + ?off:int -> + ?len:int -> + string -> + (string, [> `Msg of string ]) result +(** Same as {!decode_exn}, but returns an explicit error message {!result} if it + fails. *) + +val encode : + ?pad:bool -> + ?alphabet:alphabet -> + ?off:int -> + ?len:int -> + string -> + (string, [> `Msg of string ]) result (** [encode s] encodes the string [s] into base64. If [pad] is false, no trailing padding is added. [pad] defaults to [true], and [alphabet] to {!default_alphabet}. @@ -83,10 +103,17 @@ val encode_string : ?pad:bool -> ?alphabet:alphabet -> string -> string trailing padding is added. [pad] defaults to [true], and [alphabet] to {!default_alphabet}. *) -val encode_sub : ?pad:bool -> ?alphabet:alphabet -> ?off:int -> ?len:int -> string -> (sub, [> `Msg of string]) result +val encode_sub : + ?pad:bool -> + ?alphabet:alphabet -> + ?off:int -> + ?len:int -> + string -> + (sub, [> `Msg of string ]) result (** Same as {!encode} but return a {!sub}-string instead a plain result. By this way, we ensure to allocate only one time result. *) -val encode_exn : ?pad:bool -> ?alphabet:alphabet -> ?off:int -> ?len:int -> string -> string +val encode_exn : + ?pad:bool -> ?alphabet:alphabet -> ?off:int -> ?len:int -> string -> string (** Same as {!encode} but raises an invalid argument exception if we retrieve an error. *) diff --git a/src/base64_rfc2045.ml b/src/base64_rfc2045.ml index 8018fe7..8375700 100644 --- a/src/base64_rfc2045.ml +++ b/src/base64_rfc2045.ml @@ -19,28 +19,32 @@ let default_alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" let io_buffer_size = 65536 + let invalid_arg fmt = Format.ksprintf (fun s -> invalid_arg s) fmt let invalid_bounds off len = invalid_arg "Invalid bounds (off: %d, len: %d)" off len -let malformed chr = - `Malformed (String.make 1 chr) +let malformed chr = `Malformed (String.make 1 chr) let unsafe_byte source off pos = Bytes.unsafe_get source (off + pos) + let unsafe_blit = Bytes.unsafe_blit + let unsafe_chr = Char.unsafe_chr + let unsafe_set_chr source off chr = Bytes.unsafe_set source off chr -type state = {quantum: int; size: int; buffer: Bytes.t} +type state = { quantum : int; size : int; buffer : Bytes.t } -let continue state (quantum, size) = `Continue {state with quantum; size} -let flush state = `Flush {state with quantum= 0; size= 0} +let continue state (quantum, size) = `Continue { state with quantum; size } + +let flush state = `Flush { state with quantum = 0; size = 0 } let table = "\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\062\255\255\255\063\052\053\054\055\056\057\058\059\060\061\255\255\255\255\255\255\255\000\001\002\003\004\005\006\007\008\009\010\011\012\013\014\015\016\017\018\019\020\021\022\023\024\025\255\255\255\255\255\255\026\027\028\029\030\031\032\033\034\035\036\037\038\039\040\041\042\043\044\045\046\047\048\049\050\051\255\255\255\255\255" -let r_repr ({quantum; size; _} as state) chr = +let r_repr ({ quantum; size; _ } as state) chr = (* assert (0 <= off && 0 <= len && off + len <= String.length source); *) (* assert (len >= 1); *) let code = Char.code table.[Char.code chr] in @@ -56,27 +60,28 @@ let r_repr ({quantum; size; _} as state) chr = flush state | _ -> malformed chr -type src = [`Channel of in_channel | `String of string | `Manual] +type src = [ `Channel of in_channel | `String of string | `Manual ] type decode = - [`Await | `End | `Wrong_padding | `Malformed of string | `Flush of string] + [ `Await | `End | `Wrong_padding | `Malformed of string | `Flush of string ] type input = - [`Line_break | `Wsp | `Padding | `Malformed of string | `Flush of state] - -type decoder = - { src: src - ; mutable i: Bytes.t - ; mutable i_off: int - ; mutable i_pos: int - ; mutable i_len: int - ; mutable s: state - ; mutable padding: int - ; mutable unsafe: bool - ; mutable byte_count: int - ; mutable limit_count: int - ; mutable pp: decoder -> input -> decode - ; mutable k: decoder -> decode } + [ `Line_break | `Wsp | `Padding | `Malformed of string | `Flush of state ] + +type decoder = { + src : src; + mutable i : Bytes.t; + mutable i_off : int; + mutable i_pos : int; + mutable i_len : int; + mutable s : state; + mutable padding : int; + mutable unsafe : bool; + mutable byte_count : int; + mutable limit_count : int; + mutable pp : decoder -> input -> decode; + mutable k : decoder -> decode; +} let i_rem decoder = decoder.i_len - decoder.i_pos + 1 @@ -87,27 +92,31 @@ let end_of_input decoder = decoder.i_len <- min_int let src decoder source off len = - if off < 0 || len < 0 || off + len > Bytes.length source then - invalid_bounds off len - else if len = 0 then end_of_input decoder + if off < 0 || len < 0 || off + len > Bytes.length source + then invalid_bounds off len + else if len = 0 + then end_of_input decoder else ( decoder.i <- source ; decoder.i_off <- off ; decoder.i_pos <- 0 ; - decoder.i_len <- len - 1 ) + decoder.i_len <- len - 1) let refill k decoder = match decoder.src with | `Manual -> decoder.k <- k ; `Await - | `String _ -> end_of_input decoder ; k decoder + | `String _ -> + end_of_input decoder ; + k decoder | `Channel ic -> let len = input ic decoder.i 0 (Bytes.length decoder.i) in src decoder decoder.i 0 len ; k decoder let dangerous decoder v = decoder.unsafe <- v + let reset decoder = decoder.limit_count <- 0 let ret k v byte_count decoder = @@ -117,9 +126,9 @@ let ret k v byte_count decoder = if decoder.limit_count > 78 then dangerous decoder true ; decoder.pp decoder v -type flush_and_malformed = [`Flush of state | `Malformed of string] +type flush_and_malformed = [ `Flush of state | `Malformed of string ] -let padding {size; _} padding = +let padding { size; _ } padding = match (size, padding) with | 0, 0 -> true | 1, _ -> false @@ -127,39 +136,42 @@ let padding {size; _} padding = | 3, 1 -> true | _ -> false -let t_flush {quantum; size; buffer} = +let t_flush { quantum; size; buffer } = match size with - | 0 | 1 -> `Flush {quantum; size; buffer= Bytes.empty} + | 0 | 1 -> `Flush { quantum; size; buffer = Bytes.empty } | 2 -> let quantum = quantum lsr 4 in `Flush - { quantum - ; size - ; buffer= Bytes.make 1 (unsafe_chr (quantum land 255)) } + { quantum; size; buffer = Bytes.make 1 (unsafe_chr (quantum land 255)) } | 3 -> let quantum = quantum lsr 2 in unsafe_set_chr buffer 0 (unsafe_chr ((quantum lsr 8) land 255)) ; unsafe_set_chr buffer 1 (unsafe_chr (quantum land 255)) ; - `Flush {quantum; size; buffer= Bytes.sub buffer 0 2} - | _ -> assert false (* this branch is impossible, size can only ever be in the range [0..3]. *) + `Flush { quantum; size; buffer = Bytes.sub buffer 0 2 } + | _ -> assert false + +(* this branch is impossible, size can only ever be in the range [0..3]. *) let wrong_padding decoder = let k _ = `End in - decoder.k <- k ; `Wrong_padding + decoder.k <- k ; + `Wrong_padding let rec t_decode_base64 chr decoder = - if decoder.padding = 0 then + if decoder.padding = 0 + then let rec go pos = function | `Continue state -> if decoder.i_len - (decoder.i_pos + pos) + 1 > 0 then ( match unsafe_byte decoder.i decoder.i_off (decoder.i_pos + pos) with - | ('A' .. 'Z' | 'a' .. 'z' | '0' .. '9' | '+' | '/') as chr -> go (succ pos) (r_repr state chr) + | ('A' .. 'Z' | 'a' .. 'z' | '0' .. '9' | '+' | '/') as chr -> + go (succ pos) (r_repr state chr) | '=' -> decoder.padding <- decoder.padding + 1 ; decoder.i_pos <- decoder.i_pos + pos + 1 ; decoder.s <- state ; - ret decode_base64 `Padding (pos+1) decoder + ret decode_base64 `Padding (pos + 1) decoder | ' ' | '\t' -> decoder.i_pos <- decoder.i_pos + pos + 1 ; decoder.s <- state ; @@ -171,17 +183,16 @@ let rec t_decode_base64 chr decoder = | chr -> decoder.i_pos <- decoder.i_pos + pos + 1 ; decoder.s <- state ; - ret decode_base64 (malformed chr) (pos+1) decoder - ) else ( + ret decode_base64 (malformed chr) (pos + 1) decoder) + else ( decoder.i_pos <- decoder.i_pos + pos ; decoder.byte_count <- decoder.byte_count + pos ; decoder.limit_count <- decoder.limit_count + pos ; decoder.s <- state ; - refill decode_base64 decoder ) + refill decode_base64 decoder) | #flush_and_malformed as v -> decoder.i_pos <- decoder.i_pos + pos ; - ret decode_base64 v pos decoder - in + ret decode_base64 v pos decoder in go 1 (r_repr decoder.s chr) else ( decoder.i_pos <- decoder.i_pos + 1 ; @@ -189,24 +200,28 @@ let rec t_decode_base64 chr decoder = and decode_base64_lf_after_cr decoder = let rem = i_rem decoder in - if rem < 0 then - ret decode_base64 (malformed '\r') 1 decoder - else if rem = 0 then refill decode_base64_lf_after_cr decoder + if rem < 0 + then ret decode_base64 (malformed '\r') 1 decoder + else if rem = 0 + then refill decode_base64_lf_after_cr decoder else match unsafe_byte decoder.i decoder.i_off decoder.i_pos with | '\n' -> - decoder.i_pos <- decoder.i_pos + 1 ; - ret decode_base64 `Line_break 2 decoder - | _ -> - ret decode_base64 (malformed '\r') 1 decoder + decoder.i_pos <- decoder.i_pos + 1 ; + ret decode_base64 `Line_break 2 decoder + | _ -> ret decode_base64 (malformed '\r') 1 decoder and decode_base64 decoder = let rem = i_rem decoder in - if rem <= 0 then - if rem < 0 then + if rem <= 0 + then + if rem < 0 + then ret (fun decoder -> - if padding decoder.s decoder.padding then `End else wrong_padding decoder ) + if padding decoder.s decoder.padding + then `End + else wrong_padding decoder) (t_flush decoder.s) 0 decoder else refill decode_base64 decoder else @@ -228,7 +243,9 @@ and decode_base64 decoder = ret decode_base64 (malformed chr) 1 decoder let pp_base64 decoder = function - | `Line_break -> reset decoder ; decoder.k decoder + | `Line_break -> + reset decoder ; + decoder.k decoder | `Wsp | `Padding -> decoder.k decoder | `Flush state -> decoder.s <- state ; @@ -242,52 +259,58 @@ let decoder src = match src with | `Manual -> (Bytes.empty, 0, 1, 0) | `Channel _ -> (Bytes.create io_buffer_size, 0, 1, 0) - | `String s -> (Bytes.unsafe_of_string s, 0, 0, String.length s - 1) - in - { src - ; i_off - ; i_pos - ; i_len - ; i - ; s= {quantum= 0; size= 0; buffer= Bytes.create 3} - ; padding= 0 - ; unsafe= false - ; byte_count= 0 - ; limit_count= 0 - ; pp - ; k } + | `String s -> (Bytes.unsafe_of_string s, 0, 0, String.length s - 1) in + { + src; + i_off; + i_pos; + i_len; + i; + s = { quantum = 0; size = 0; buffer = Bytes.create 3 }; + padding = 0; + unsafe = false; + byte_count = 0; + limit_count = 0; + pp; + k; + } let decode decoder = decoder.k decoder + let decoder_byte_count decoder = decoder.byte_count + let decoder_src decoder = decoder.src + let decoder_dangerous decoder = decoder.unsafe (* / *) let invalid_encode () = invalid_arg "Expected `Await encode" -type dst = [`Channel of out_channel | `Buffer of Buffer.t | `Manual] -type encode = [`Await | `End | `Char of char] - -type encoder = - { dst: dst - ; mutable o: Bytes.t - ; mutable o_off: int - ; mutable o_pos: int - ; mutable o_len: int - ; mutable c_col: int - ; i: Bytes.t - ; mutable s: int - ; t: Bytes.t - ; mutable t_pos: int - ; mutable t_len: int - ; mutable k: encoder -> encode -> [`Ok | `Partial] } +type dst = [ `Channel of out_channel | `Buffer of Buffer.t | `Manual ] + +type encode = [ `Await | `End | `Char of char ] + +type encoder = { + dst : dst; + mutable o : Bytes.t; + mutable o_off : int; + mutable o_pos : int; + mutable o_len : int; + mutable c_col : int; + i : Bytes.t; + mutable s : int; + t : Bytes.t; + mutable t_pos : int; + mutable t_len : int; + mutable k : encoder -> encode -> [ `Ok | `Partial ]; +} let o_rem encoder = encoder.o_len - encoder.o_pos + 1 let dst encoder source off len = - if off < 0 || len < 0 || off + len > Bytes.length source then - invalid_bounds off len ; + if off < 0 || len < 0 || off + len > Bytes.length source + then invalid_bounds off len ; encoder.o <- source ; encoder.o_off <- off ; encoder.o_pos <- 0 ; @@ -322,36 +345,37 @@ let rec t_flush k encoder = let blit encoder len = unsafe_blit encoder.t encoder.t_pos encoder.o encoder.o_pos len ; encoder.o_pos <- encoder.o_pos + len ; - encoder.t_pos <- encoder.t_pos + len - in + encoder.t_pos <- encoder.t_pos + len in let rem = o_rem encoder in let len = encoder.t_len - encoder.t_pos + 1 in - if rem < len then ( + if rem < len + then ( blit encoder rem ; - flush (t_flush k) encoder ) - else ( blit encoder len ; k encoder ) + flush (t_flush k) encoder) + else ( + blit encoder len ; + k encoder) let rec encode_line_break k encoder = let rem = o_rem encoder in let s, j, k = - if rem < 2 then ( + if rem < 2 + then ( t_range encoder 2 ; - (encoder.t, 0, t_flush k) ) + (encoder.t, 0, t_flush k)) else let j = encoder.o_pos in encoder.o_pos <- encoder.o_pos + 2 ; - (encoder.o, encoder.o_off + j, k) - in + (encoder.o, encoder.o_off + j, k) in unsafe_set_chr s j '\r' ; unsafe_set_chr s (j + 1) '\n' ; encoder.c_col <- 0 ; k encoder and encode_char chr k (encoder : encoder) = - if encoder.s >= 2 then ( - let a, b, c = - (unsafe_byte encoder.i 0 0, unsafe_byte encoder.i 0 1, chr) - in + if encoder.s >= 2 + then ( + let a, b, c = (unsafe_byte encoder.i 0 0, unsafe_byte encoder.i 0 1, chr) in encoder.s <- 0 ; let quantum = (Char.code a lsl 16) + (Char.code b lsl 8) + Char.code c in let a = quantum lsr 18 in @@ -360,23 +384,23 @@ and encode_char chr k (encoder : encoder) = let d = quantum land 63 in let rem = o_rem encoder in let s, j, k = - if rem < 4 then ( + if rem < 4 + then ( t_range encoder 4 ; - (encoder.t, 0, t_flush (k 4)) ) + (encoder.t, 0, t_flush (k 4))) else let j = encoder.o_pos in encoder.o_pos <- encoder.o_pos + 4 ; - (encoder.o, encoder.o_off + j, k 4) - in + (encoder.o, encoder.o_off + j, k 4) in unsafe_set_chr s j default_alphabet.[a] ; unsafe_set_chr s (j + 1) default_alphabet.[b] ; unsafe_set_chr s (j + 2) default_alphabet.[c] ; unsafe_set_chr s (j + 3) default_alphabet.[d] ; - flush k encoder ) + flush k encoder) else ( unsafe_set_chr encoder.i encoder.s chr ; encoder.s <- encoder.s + 1 ; - k 0 encoder ) + k 0 encoder) and encode_trailing k encoder = match encoder.s with @@ -389,14 +413,14 @@ and encode_trailing k encoder = let d = quantum land 63 in let rem = o_rem encoder in let s, j, k = - if rem < 4 then ( + if rem < 4 + then ( t_range encoder 4 ; - (encoder.t, 0, t_flush (k 4)) ) + (encoder.t, 0, t_flush (k 4))) else let j = encoder.o_pos in encoder.o_pos <- encoder.o_pos + 4 ; - (encoder.o, encoder.o_off + j, k 4) - in + (encoder.o, encoder.o_off + j, k 4) in unsafe_set_chr s j default_alphabet.[b] ; unsafe_set_chr s (j + 1) default_alphabet.[c] ; unsafe_set_chr s (j + 2) default_alphabet.[d] ; @@ -410,14 +434,14 @@ and encode_trailing k encoder = let d = quantum land 63 in let rem = o_rem encoder in let s, j, k = - if rem < 4 then ( + if rem < 4 + then ( t_range encoder 4 ; - (encoder.t, 0, t_flush (k 4)) ) + (encoder.t, 0, t_flush (k 4))) else let j = encoder.o_pos in encoder.o_pos <- encoder.o_pos + 4 ; - (encoder.o, encoder.o_off + j, k 4) - in + (encoder.o, encoder.o_off + j, k 4) in unsafe_set_chr s j default_alphabet.[c] ; unsafe_set_chr s (j + 1) default_alphabet.[d] ; unsafe_set_chr s (j + 2) '=' ; @@ -430,19 +454,19 @@ and encode_base64 encoder v = let k col_count encoder = encoder.c_col <- encoder.c_col + col_count ; encoder.k <- encode_base64 ; - `Ok - in + `Ok in match v with | `Await -> k 0 encoder | `End -> - if encoder.c_col = 76 then - encode_line_break (fun encoder -> encode_base64 encoder v) encoder + if encoder.c_col = 76 + then encode_line_break (fun encoder -> encode_base64 encoder v) encoder else encode_trailing k encoder | `Char chr -> let rem = o_rem encoder in - if rem < 1 then flush (fun encoder -> encode_base64 encoder v) encoder - else if encoder.c_col = 76 then - encode_line_break (fun encoder -> encode_base64 encoder v) encoder + if rem < 1 + then flush (fun encoder -> encode_base64 encoder v) encoder + else if encoder.c_col = 76 + then encode_line_break (fun encoder -> encode_base64 encoder v) encoder else encode_char chr k encoder let encoder dst = @@ -450,20 +474,22 @@ let encoder dst = match dst with | `Manual -> (Bytes.empty, 1, 0, 0) | `Buffer _ | `Channel _ -> - (Bytes.create io_buffer_size, 0, 0, io_buffer_size - 1) - in - { dst - ; o_off - ; o_pos - ; o_len - ; o - ; t= Bytes.create 4 - ; t_pos= 1 - ; t_len= 0 - ; c_col= 0 - ; i= Bytes.create 3 - ; s= 0 - ; k= encode_base64 } + (Bytes.create io_buffer_size, 0, 0, io_buffer_size - 1) in + { + dst; + o_off; + o_pos; + o_len; + o; + t = Bytes.create 4; + t_pos = 1; + t_len = 0; + c_col = 0; + i = Bytes.create 3; + s = 0; + k = encode_base64; + } let encode encoder = encoder.k encoder + let encoder_dst encoder = encoder.dst diff --git a/src/base64_rfc2045.mli b/src/base64_rfc2045.mli index 8fc3330..90fef82 100644 --- a/src/base64_rfc2045.mli +++ b/src/base64_rfc2045.mli @@ -20,15 +20,15 @@ val default_alphabet : string (** A 64-character string specifying the regular Base64 alphabet. *) -(** The type for decoders. *) type decoder +(** The type for decoders. *) +type src = [ `Manual | `Channel of in_channel | `String of string ] (** The type for input sources. With a [`Manual] source the client must provide input with {!src}. *) -type src = [`Manual | `Channel of in_channel | `String of string] type decode = - [`Await | `End | `Flush of string | `Malformed of string | `Wrong_padding] + [ `Await | `End | `Flush of string | `Malformed of string | `Wrong_padding ] val src : decoder -> Bytes.t -> int -> int -> unit (** [src d s j l] provides [d] with [l] bytes to read, starting at [j] in [s]. @@ -66,19 +66,19 @@ val decoder_dangerous : decoder -> bool still continue to decode even if [decoder_dangerous d] returns [true]. Nothing grow automatically internally in this state. *) +type dst = [ `Channel of out_channel | `Buffer of Buffer.t | `Manual ] (** The type for output destinations. With a [`Manual] destination the client must provide output storage with {!dst}. *) -type dst = [`Channel of out_channel | `Buffer of Buffer.t | `Manual] -type encode = [`Await | `End | `Char of char] +type encode = [ `Await | `End | `Char of char ] -(** The type for Base64 (RFC2045) encoder. *) type encoder +(** The type for Base64 (RFC2045) encoder. *) val encoder : dst -> encoder (** [encoder dst] is an encoder for Base64 (RFC2045) that outputs to [dst]. *) -val encode : encoder -> encode -> [`Ok | `Partial] +val encode : encoder -> encode -> [ `Ok | `Partial ] (** [encode e v]: is {ul {- [`Partial] iff [e] has a [`Manual] destination and needs more output storage. The client must use {!dst} to provide a new buffer and then call {!encode} with [`Await] until [`Ok] is returned.} {- @@ -99,8 +99,8 @@ val encoder_dst : encoder -> dst val dst : encoder -> Bytes.t -> int -> int -> unit (** [dst e s j l] provides [e] with [l] bytes to write, starting at [j] in [s]. This byte range is written by calls to {!encode} with [e] until [`Partial] - is returned. Use {!dst_rem} to know the remaining number of non-written - free bytes in [s]. *) + is returned. Use {!dst_rem} to know the remaining number of non-written free + bytes in [s]. *) val dst_rem : encoder -> int (** [dst_rem e] is the remaining number of non-written, free bytes in the last diff --git a/src/dune b/src/dune index f5f41eb..c260eb1 100644 --- a/src/dune +++ b/src/dune @@ -7,7 +7,8 @@ (rule (targets unsafe.ml) (deps unsafe_pre407.ml unsafe_stable.ml) - (action (run ../config/config.exe))) + (action + (run ../config/config.exe))) (library (name base64_rfc2045) diff --git a/src/unsafe_pre407.ml b/src/unsafe_pre407.ml index 86d3db9..23132b3 100644 --- a/src/unsafe_pre407.ml +++ b/src/unsafe_pre407.ml @@ -1 +1,2 @@ -external unsafe_set_uint16 : bytes -> int -> int -> unit = "%caml_string_set16u" [@@noalloc] +external unsafe_set_uint16 : bytes -> int -> int -> unit = "%caml_string_set16u" + [@@noalloc] diff --git a/src/unsafe_stable.ml b/src/unsafe_stable.ml index 01b2db4..24c2cd3 100644 --- a/src/unsafe_stable.ml +++ b/src/unsafe_stable.ml @@ -1 +1,2 @@ -external unsafe_set_uint16 : bytes -> int -> int -> unit = "%caml_bytes_set16u" [@@noalloc] +external unsafe_set_uint16 : bytes -> int -> int -> unit = "%caml_bytes_set16u" + [@@noalloc] diff --git a/test/dune b/test/dune index 8a0c40b..0207773 100644 --- a/test/dune +++ b/test/dune @@ -1,8 +1,11 @@ (executable + (modes byte exe) (name test) (libraries base64 base64.rfc2045 rresult alcotest bos)) -(alias - (name runtest) - (deps (:exe test.exe)) - (action (run %{exe} --color=always))) +(rule + (alias runtest) + (deps + (:exe test.exe)) + (action + (run %{exe} --color=always))) diff --git a/test/test.ml b/test/test.ml index e52cc19..41a4b9a 100644 --- a/test/test.ml +++ b/test/test.ml @@ -28,203 +28,260 @@ open Rresult BASE64("foobar") = "Zm9vYmFy" *) -let rfc4648_tests = [ - "", ""; - "f", "Zg=="; - "fo", "Zm8="; - "foo", "Zm9v"; - "foob", "Zm9vYg=="; - "fooba", "Zm9vYmE="; - "foobar", "Zm9vYmFy"; -] - -let hannes_tests = [ - "dummy", "ZHVtbXk="; - "dummy", "ZHVtbXk"; - "dummy", "ZHVtbXk=="; - "dummy", "ZHVtbXk==="; - "dummy", "ZHVtbXk===="; - "dummy", "ZHVtbXk====="; - "dummy", "ZHVtbXk======"; -] - -let php_tests = [ - "πάντα χωρεῖ καὶ οὐδὲν μένει …", "z4DOrM69z4TOsSDPh8-Jz4HOteG_liDOus6x4b22IM6_4b2QzrThvbLOvSDOvM6tzr3Otc65IOKApg" -] - -let rfc3548_tests = [ - "\x14\xfb\x9c\x03\xd9\x7e", "FPucA9l+"; - "\x14\xfb\x9c\x03\xd9", "FPucA9k="; - "\x14\xfb\x9c\x03", "FPucAw=="; -] - -let cfcs_tests = [ - 0, 2, "\004", "BB"; - 1, 2, "\004", "ABB"; - 1, 2, "\004", "ABBA"; - 2, 2, "\004", "AABBA"; - 2, 2, "\004", "AABBAA"; - 0, 0, "", "BB"; - 1, 0, "", "BB"; - 2, 0, "", "BB"; -] +let rfc4648_tests = + [ + ("", ""); + ("f", "Zg=="); + ("fo", "Zm8="); + ("foo", "Zm9v"); + ("foob", "Zm9vYg=="); + ("fooba", "Zm9vYmE="); + ("foobar", "Zm9vYmFy"); + ] + +let hannes_tests = + [ + ("dummy", "ZHVtbXk="); + ("dummy", "ZHVtbXk"); + ("dummy", "ZHVtbXk=="); + ("dummy", "ZHVtbXk==="); + ("dummy", "ZHVtbXk===="); + ("dummy", "ZHVtbXk====="); + ("dummy", "ZHVtbXk======"); + ] + +let php_tests = + [ + ( "πάντα χωρεῖ καὶ οὐδὲν μένει …", + "z4DOrM69z4TOsSDPh8-Jz4HOteG_liDOus6x4b22IM6_4b2QzrThvbLOvSDOvM6tzr3Otc65IOKApg" + ); + ] + +let rfc3548_tests = + [ + ("\x14\xfb\x9c\x03\xd9\x7e", "FPucA9l+"); + ("\x14\xfb\x9c\x03\xd9", "FPucA9k="); + ("\x14\xfb\x9c\x03", "FPucAw=="); + ] + +let cfcs_tests = + [ + (0, 2, "\004", "BB"); + (1, 2, "\004", "ABB"); + (1, 2, "\004", "ABBA"); + (2, 2, "\004", "AABBA"); + (2, 2, "\004", "AABBAA"); + (0, 0, "", "BB"); + (1, 0, "", "BB"); + (2, 0, "", "BB"); + ] let nocrypto_tests = - [ "\x00\x5a\x6d\x39\x76", None - ; "\x5a\x6d\x39\x76", Some "\x66\x6f\x6f" - ; "\x5a\x6d\x39\x76\x76", None - ; "\x5a\x6d\x39\x76\x76\x76", None - ; "\x5a\x6d\x39\x76\x76\x76\x76", None - ; "\x5a\x6d\x39\x76\x00", None - ; "\x5a\x6d\x39\x76\x62\x77\x3d\x3d", Some "\x66\x6f\x6f\x6f" - ; "\x5a\x6d\x39\x76\x62\x77\x3d\x3d\x00", None - ; "\x5a\x6d\x39\x76\x62\x77\x3d\x3d\x00\x01", None - ; "\x5a\x6d\x39\x76\x62\x77\x3d\x3d\x00\x01\x02", None - ; "\x5a\x6d\x39\x76\x62\x77\x3d\x3d\x00\x01\x02\x03", None - ; "\x5a\x6d\x39\x76\x62\x32\x38\x3d", Some "\x66\x6f\x6f\x6f\x6f" - ; "\x5a\x6d\x39\x76\x62\x32\x39\x76", Some "\x66\x6f\x6f\x6f\x6f\x6f" - ; "YWE=", Some "aa" - ; "YWE==", None - ; "YWE===", None - ; "YWE=====", None - ; "YWE======", None ] + [ + ("\x00\x5a\x6d\x39\x76", None); + ("\x5a\x6d\x39\x76", Some "\x66\x6f\x6f"); + ("\x5a\x6d\x39\x76\x76", None); + ("\x5a\x6d\x39\x76\x76\x76", None); + ("\x5a\x6d\x39\x76\x76\x76\x76", None); + ("\x5a\x6d\x39\x76\x00", None); + ("\x5a\x6d\x39\x76\x62\x77\x3d\x3d", Some "\x66\x6f\x6f\x6f"); + ("\x5a\x6d\x39\x76\x62\x77\x3d\x3d\x00", None); + ("\x5a\x6d\x39\x76\x62\x77\x3d\x3d\x00\x01", None); + ("\x5a\x6d\x39\x76\x62\x77\x3d\x3d\x00\x01\x02", None); + ("\x5a\x6d\x39\x76\x62\x77\x3d\x3d\x00\x01\x02\x03", None); + ("\x5a\x6d\x39\x76\x62\x32\x38\x3d", Some "\x66\x6f\x6f\x6f\x6f"); + ("\x5a\x6d\x39\x76\x62\x32\x39\x76", Some "\x66\x6f\x6f\x6f\x6f\x6f"); + ("YWE=", Some "aa"); + ("YWE==", None); + ("YWE===", None); + ("YWE=====", None); + ("YWE======", None); + ] let alphabet_size () = - List.iter (fun (name,alphabet) -> - Alcotest.(check int) (sprintf "Alphabet size %s = 64" name) - 64 (Base64.length_alphabet alphabet)) - ["default",Base64.default_alphabet; "uri_safe",Base64.uri_safe_alphabet] + List.iter + (fun (name, alphabet) -> + Alcotest.(check int) + (sprintf "Alphabet size %s = 64" name) + 64 + (Base64.length_alphabet alphabet)) + [ + ("default", Base64.default_alphabet); + ("uri_safe", Base64.uri_safe_alphabet); + ] (* Encode using OpenSSL `base64` utility *) let openssl_encode buf = - Bos.(OS.Cmd.in_string buf |> OS.Cmd.run_io (Cmd.v "base64") |> OS.Cmd.to_string ~trim:true) |> - function | Ok r -> prerr_endline r; r | Error (`Msg e) -> raise (Failure (sprintf "OpenSSL decode: %s" e)) + Bos.( + OS.Cmd.in_string buf + |> OS.Cmd.run_io (Cmd.v "base64") + |> OS.Cmd.to_string ~trim:true) + |> function + | Ok r -> + prerr_endline r ; + r + | Error (`Msg e) -> raise (Failure (sprintf "OpenSSL decode: %s" e)) (* Encode using this library *) -let lib_encode buf = - Base64.encode_exn ~pad:true buf +let lib_encode buf = Base64.encode_exn ~pad:true buf let test_rfc4648 () = - List.iter (fun (c,r) -> - (* Base64 vs openssl *) - Alcotest.(check string) (sprintf "encode %s" c) (openssl_encode c) (lib_encode c); - (* Base64 vs test cases above *) - Alcotest.(check string) (sprintf "encode rfc4648 %s" c) r (lib_encode c); - (* Base64 decode vs library *) - Alcotest.(check string) (sprintf "decode %s" r) c (Base64.decode_exn r); - ) rfc4648_tests + List.iter + (fun (c, r) -> + (* Base64 vs openssl *) + Alcotest.(check string) + (sprintf "encode %s" c) (openssl_encode c) (lib_encode c) ; + (* Base64 vs test cases above *) + Alcotest.(check string) (sprintf "encode rfc4648 %s" c) r (lib_encode c) ; + (* Base64 decode vs library *) + Alcotest.(check string) (sprintf "decode %s" r) c (Base64.decode_exn r)) + rfc4648_tests let test_rfc3548 () = - List.iter (fun (c,r) -> - (* Base64 vs openssl *) - Alcotest.(check string) (sprintf "encode %s" c) (openssl_encode c) (lib_encode c); - (* Base64 vs test cases above *) - Alcotest.(check string) (sprintf "encode rfc3548 %s" c) r (lib_encode c); - (* Base64 decode vs library *) - Alcotest.(check string) (sprintf "decode %s" r) c (Base64.decode_exn r); - ) rfc3548_tests + List.iter + (fun (c, r) -> + (* Base64 vs openssl *) + Alcotest.(check string) + (sprintf "encode %s" c) (openssl_encode c) (lib_encode c) ; + (* Base64 vs test cases above *) + Alcotest.(check string) (sprintf "encode rfc3548 %s" c) r (lib_encode c) ; + (* Base64 decode vs library *) + Alcotest.(check string) (sprintf "decode %s" r) c (Base64.decode_exn r)) + rfc3548_tests let test_hannes () = - List.iter (fun (c,r) -> - (* Base64 vs test cases above *) - Alcotest.(check string) (sprintf "decode %s" r) c (Base64.decode_exn ~pad:false r); - ) hannes_tests + List.iter + (fun (c, r) -> + (* Base64 vs test cases above *) + Alcotest.(check string) + (sprintf "decode %s" r) c + (Base64.decode_exn ~pad:false r)) + hannes_tests let test_php () = - List.iter (fun (c,r) -> - Alcotest.(check string) (sprintf "decode %s" r) c (Base64.decode_exn ~pad:false ~alphabet:Base64.uri_safe_alphabet r); - ) php_tests + List.iter + (fun (c, r) -> + Alcotest.(check string) + (sprintf "decode %s" r) c + (Base64.decode_exn ~pad:false ~alphabet:Base64.uri_safe_alphabet r)) + php_tests let test_cfcs () = - List.iter (fun (off, len, c,r) -> - Alcotest.(check string) (sprintf "decode %s" r) c (Base64.decode_exn ~pad:false ~off ~len r); - ) cfcs_tests + List.iter + (fun (off, len, c, r) -> + Alcotest.(check string) + (sprintf "decode %s" r) c + (Base64.decode_exn ~pad:false ~off ~len r)) + cfcs_tests let test_nocrypto () = - List.iter (fun (input, res) -> - let res' = match Base64.decode ~pad:true input with - | Ok v -> Some v - | Error _ -> None in - Alcotest.(check (option string)) (sprintf "decode %S" input) res' res ; - ) nocrypto_tests + List.iter + (fun (input, res) -> + let res' = + match Base64.decode ~pad:true input with + | Ok v -> Some v + | Error _ -> None in + Alcotest.(check (option string)) (sprintf "decode %S" input) res' res) + nocrypto_tests exception Malformed + exception Wrong_padding let strict_base64_rfc2045_of_string x = let decoder = Base64_rfc2045.decoder (`String x) in let res = Buffer.create 16 in - let rec go () = match Base64_rfc2045.decode decoder with - | `End -> () - | `Wrong_padding -> raise Wrong_padding - | `Malformed _ -> raise Malformed - | `Flush x -> Buffer.add_string res x ; go () - | `Await -> Alcotest.failf "Retrieve impossible case: `Await" in + let rec go () = + match Base64_rfc2045.decode decoder with + | `End -> () + | `Wrong_padding -> raise Wrong_padding + | `Malformed _ -> raise Malformed + | `Flush x -> + Buffer.add_string res x ; + go () + | `Await -> Alcotest.failf "Retrieve impossible case: `Await" in Base64_rfc2045.src decoder (Bytes.unsafe_of_string x) 0 (String.length x) ; - go () ; Buffer.contents res + go () ; + Buffer.contents res let relaxed_base64_rfc2045_of_string x = let decoder = Base64_rfc2045.decoder (`String x) in let res = Buffer.create 16 in - let rec go () = match Base64_rfc2045.decode decoder with - | `End -> () - | `Wrong_padding -> go () - | `Malformed _ -> go () - | `Flush x -> Buffer.add_string res x ; go () - | `Await -> Alcotest.failf "Retrieve impossible case: `Await" in + let rec go () = + match Base64_rfc2045.decode decoder with + | `End -> () + | `Wrong_padding -> go () + | `Malformed _ -> go () + | `Flush x -> + Buffer.add_string res x ; + go () + | `Await -> Alcotest.failf "Retrieve impossible case: `Await" in Base64_rfc2045.src decoder (Bytes.unsafe_of_string x) 0 (String.length x) ; - go () ; Buffer.contents res + go () ; + Buffer.contents res let test_strict_rfc2045 = - [ "c2FsdXQgbGVzIGNvcGFpbnMgZmF1dCBhYnNvbHVtZW50IHF1ZSBqZSBkw6lwYXNzZSBsZXMgODAg\r\n\ - Y2hhcmFjdGVycyBwb3VyIHZvaXIgc2kgbW9uIGVuY29kZXIgZml0cyBiaWVuIGRhbnMgbGVzIGxp\r\n\ - bWl0ZXMgZGUgbGEgUkZDIDIwNDUgLi4u", - "salut les copains faut absolument que je dépasse les 80 characters pour voir si \ - mon encoder fits bien dans les limites de la RFC 2045 ..." - ; "", "" - ; "Zg==", "f" - ; "Zm8=", "fo" - ; "Zm9v", "foo" - ; "Zm9vYg==", "foob" - ; "Zm9vYmE=", "fooba" - ; "Zm9vYmFy", "foobar" ] + [ + ( "c2FsdXQgbGVzIGNvcGFpbnMgZmF1dCBhYnNvbHVtZW50IHF1ZSBqZSBkw6lwYXNzZSBsZXMgODAg\r\n\ + Y2hhcmFjdGVycyBwb3VyIHZvaXIgc2kgbW9uIGVuY29kZXIgZml0cyBiaWVuIGRhbnMgbGVzIGxp\r\n\ + bWl0ZXMgZGUgbGEgUkZDIDIwNDUgLi4u", + "salut les copains faut absolument que je dépasse les 80 characters \ + pour voir si mon encoder fits bien dans les limites de la RFC 2045 ..." + ); + ("", ""); + ("Zg==", "f"); + ("Zm8=", "fo"); + ("Zm9v", "foo"); + ("Zm9vYg==", "foob"); + ("Zm9vYmE=", "fooba"); + ("Zm9vYmFy", "foobar"); + ] let test_relaxed_rfc2045 = - [ "Zg", "f" - ; "Zm\n8", "fo" - ; "Zm\r9v", "foo" - ; "Zm9 vYg", "foob" - ; "Zm9\r\n vYmE", "fooba" - ; "Zm9évYmFy", "foobar" ] + [ + ("Zg", "f"); + ("Zm\n8", "fo"); + ("Zm\r9v", "foo"); + ("Zm9 vYg", "foob"); + ("Zm9\r\n vYmE", "fooba"); + ("Zm9évYmFy", "foobar"); + ] let strict_base64_rfc2045_to_string x = let res = Buffer.create 16 in let encoder = Base64_rfc2045.encoder (`Buffer res) in String.iter - (fun chr -> match Base64_rfc2045.encode encoder (`Char chr) with + (fun chr -> + match Base64_rfc2045.encode encoder (`Char chr) with | `Ok -> () - | `Partial -> Alcotest.failf "Retrieve impossible case for (`Char %02x): `Partial" (Char.code chr)) + | `Partial -> + Alcotest.failf "Retrieve impossible case for (`Char %02x): `Partial" + (Char.code chr)) x ; match Base64_rfc2045.encode encoder `End with | `Ok -> Buffer.contents res | `Partial -> Alcotest.fail "Retrieve impossible case for `End: `Partial" let test_strict_with_malformed_input_rfc2045 = - List.mapi (fun i (has, _) -> - Alcotest.test_case (Fmt.strf "strict rfc2045 - %02d" i) `Quick @@ fun () -> + List.mapi + (fun i (has, _) -> + Alcotest.test_case (Fmt.strf "strict rfc2045 - %02d" i) `Quick + @@ fun () -> try let _ = strict_base64_rfc2045_of_string has in Alcotest.failf "Strict parser valids malformed input: %S" has - with Malformed | Wrong_padding -> () ) + with Malformed | Wrong_padding -> ()) test_relaxed_rfc2045 let test_strict_rfc2045 = - List.mapi (fun i (has, expect) -> - Alcotest.test_case (Fmt.strf "strict rfc2045 - %02d" i) `Quick @@ fun () -> + List.mapi + (fun i (has, expect) -> + Alcotest.test_case (Fmt.strf "strict rfc2045 - %02d" i) `Quick + @@ fun () -> try let res0 = strict_base64_rfc2045_of_string has in let res1 = strict_base64_rfc2045_to_string res0 in @@ -234,26 +291,32 @@ let test_strict_rfc2045 = test_strict_rfc2045 let test_relaxed_rfc2045 = - List.mapi (fun i (has, expect) -> - Alcotest.test_case (Fmt.strf "relaxed rfc2045 - %02d" i) `Quick @@ fun () -> + List.mapi + (fun i (has, expect) -> + Alcotest.test_case (Fmt.strf "relaxed rfc2045 - %02d" i) `Quick + @@ fun () -> let res0 = relaxed_base64_rfc2045_of_string has in Alcotest.(check string) "decode(x)" res0 expect) test_relaxed_rfc2045 -let test_invariants = [ "Alphabet size", `Quick, alphabet_size ] -let test_codec = [ "RFC4648 test vectors", `Quick, test_rfc4648 - ; "RFC3548 test vectors", `Quick, test_rfc3548 - ; "Hannes test vectors", `Quick, test_hannes - ; "Cfcs test vectors", `Quick, test_cfcs - ; "PHP test vectors", `Quick, test_php - ; "Nocrypto test vectors", `Quick, test_nocrypto ] +let test_invariants = [ ("Alphabet size", `Quick, alphabet_size) ] -let () = - Alcotest.run "Base64" [ - "invariants", test_invariants; - "codec", test_codec; - "rfc2045 (0)", test_strict_rfc2045; - "rfc2045 (1)", test_strict_with_malformed_input_rfc2045; - "rfc2045 (2)", test_relaxed_rfc2045; +let test_codec = + [ + ("RFC4648 test vectors", `Quick, test_rfc4648); + ("RFC3548 test vectors", `Quick, test_rfc3548); + ("Hannes test vectors", `Quick, test_hannes); + ("Cfcs test vectors", `Quick, test_cfcs); + ("PHP test vectors", `Quick, test_php); + ("Nocrypto test vectors", `Quick, test_nocrypto); ] +let () = + Alcotest.run "Base64" + [ + ("invariants", test_invariants); + ("codec", test_codec); + ("rfc2045 (0)", test_strict_rfc2045); + ("rfc2045 (1)", test_strict_with_malformed_input_rfc2045); + ("rfc2045 (2)", test_relaxed_rfc2045); + ]