RFC 2045 decoder: report malformed input through the decoder and add tests.

Summary of what this patch does:
  * `malformed` now takes the offending character and wraps it in `Malformed.
  * The 2-byte look-ahead buffer (fields h / h_len / h_need, helpers t_need,
    t_fill, t_crlf, r_crlf) is removed. A CR is now handled by
    decode_base64_lf_after_cr, which yields `Line_break when the next byte is
    LF and `Malformed "\r" otherwise (including at end of input).
  * Malformed bytes, padding and white space are consumed (i_pos is advanced
    past them) and reported through `ret`, so decoding can continue.
  * After `Wrong_padding has been returned, the continuation is set to one
    that returns `End.
  * decode_base64_lf_after_cr receives the number of bytes read before the CR
    so that they are included in the count handed to `ret`, in line with the
    other branches of the inner loop.
  * test/: new strict / malformed / relaxed RFC 2045 test lists, registered
    under distinct suite names.

diff --git a/src/base64_rfc2045.ml b/src/base64_rfc2045.ml
index b7fc7e9..8018fe7 100644
--- a/src/base64_rfc2045.ml
+++ b/src/base64_rfc2045.ml
@@ -24,8 +24,8 @@ let invalid_arg fmt = Format.ksprintf (fun s -> invalid_arg s) fmt
 let invalid_bounds off len =
   invalid_arg "Invalid bounds (off: %d, len: %d)" off len
 
-let malformed source off pos len =
-  `Malformed (Bytes.sub_string source (off + pos) len)
+let malformed chr =
+  `Malformed (String.make 1 chr)
 
 let unsafe_byte source off pos = Bytes.unsafe_get source (off + pos)
 let unsafe_blit = Bytes.unsafe_blit
@@ -54,14 +54,7 @@ let r_repr ({quantum; size; _} as state) chr =
         unsafe_set_chr state.buffer 2
           (unsafe_chr ((quantum lsl 6) lor code land 255)) ;
         flush state
-  | _ -> malformed (Bytes.make 1 chr) 0 0 1
-
-let r_crlf source off len =
-  (* assert (0 <= off && 0 <= len && off + len <= String.length source); *)
-  (* assert (len = 2); *)
-  match Bytes.sub_string source off len with
-  | "\r\n" -> `Line_break
-  | _ -> malformed source off 0 len
+  | _ -> malformed chr
 
 type src = [`Channel of in_channel | `String of string | `Manual]
 
@@ -78,9 +71,6 @@ type decoder =
   ; mutable i_pos: int
   ; mutable i_len: int
   ; mutable s: state
-  ; h: Bytes.t
-  ; mutable h_len: int
-  ; mutable h_need: int
   ; mutable padding: int
   ; mutable unsafe: bool
   ; mutable byte_count: int
@@ -127,33 +117,6 @@ let ret k v byte_count decoder =
   if decoder.limit_count > 78 then dangerous decoder true ;
   decoder.pp decoder v
 
-[@@@warning "-32"]
-
-let is_b64 = function
-  | 'A' .. 'Z' | 'a' .. 'z' | '0' .. '9' | '+' | '/' -> true
-  | _ -> false
-
-let t_need decoder need =
-  decoder.h_len <- 0 ;
-  decoder.h_need <- need
-
-let rec t_fill k decoder =
-  let blit decoder len =
-    unsafe_blit decoder.i
-      (decoder.i_off + decoder.i_pos)
-      decoder.h decoder.h_len len ;
-    decoder.i_pos <- decoder.i_pos + len ;
-    decoder.h_len <- decoder.h_len + len
-  in
-  let rem = i_rem decoder in
-  if rem < 0 (* end of input *) then k decoder
-  else
-    let need = decoder.h_need - decoder.h_len in
-    if rem < need then (
-      blit decoder rem ;
-      refill (t_fill k) decoder )
-    else ( blit decoder need ; k decoder )
-
 type flush_and_malformed = [`Flush of state | `Malformed of string]
 
 let padding {size; _} padding =
@@ -164,15 +127,7 @@ let padding {size; _} padding =
   | 3, 1 -> true
   | _ -> false
 
-let rec t_crlf decoder =
-  if decoder.h_len < decoder.h_need then
-    ret decode_base64
-      (malformed decoder.h 0 0 decoder.h_len)
-      decoder.h_len decoder
-  else
-    ret decode_base64 (r_crlf decoder.h 0 decoder.h_len) decoder.h_len decoder
-
-and t_flush {quantum; size; buffer} =
+let t_flush {quantum; size; buffer} =
   match size with
   | 0 | 1 -> `Flush {quantum; size; buffer= Bytes.empty}
   | 2 ->
@@ -186,9 +141,13 @@ and t_flush {quantum; size; buffer} =
       unsafe_set_chr buffer 0 (unsafe_chr ((quantum lsr 8) land 255)) ;
       unsafe_set_chr buffer 1 (unsafe_chr (quantum land 255)) ;
       `Flush {quantum; size; buffer= Bytes.sub buffer 0 2}
-  | _ -> malformed buffer 0 0 3
+  | _ -> assert false (* this branch is impossible, size can only ever be in the range [0..3]. *)
+
+let wrong_padding decoder =
+  let k _ = `End in
+  decoder.k <- k ; `Wrong_padding
 
-and t_decode_base64 chr decoder =
+let rec t_decode_base64 chr decoder =
   if decoder.padding = 0 then
     let rec go pos = function
       | `Continue state ->
@@ -197,22 +156,26 @@ and t_decode_base64 chr decoder =
             match unsafe_byte decoder.i decoder.i_off (decoder.i_pos + pos) with
             | ('A' .. 'Z' | 'a' .. 'z' | '0' .. '9' | '+' | '/') as chr -> go (succ pos) (r_repr state chr)
             | '=' ->
-                decoder.i_pos <- decoder.i_pos + pos ;
+                decoder.padding <- decoder.padding + 1 ;
+                decoder.i_pos <- decoder.i_pos + pos + 1 ;
                 decoder.s <- state ;
-                ret decode_base64 `Padding pos decoder
+                ret decode_base64 `Padding (pos + 1) decoder
             | ' ' | '\t' ->
-                decoder.i_pos <- decoder.i_pos + pos ;
+                decoder.i_pos <- decoder.i_pos + pos + 1 ;
                 decoder.s <- state ;
-                ret decode_base64 `Wsp pos decoder
+                ret decode_base64 `Wsp (pos + 1) decoder
             | '\r' ->
-                decoder.i_pos <- decoder.i_pos + pos ;
+                decoder.i_pos <- decoder.i_pos + pos + 1 ;
+                decoder.s <- state ;
+                decode_base64_lf_after_cr pos decoder (* [pos] bytes were read before the CR and are not counted yet *)
+            | chr ->
+                decoder.i_pos <- decoder.i_pos + pos + 1 ;
                 decoder.s <- state ;
-                t_need decoder 2 ;
-                t_fill t_crlf decoder
-            | chr -> malformed (Bytes.make 1 chr) 0 0 1
+                ret decode_base64 (malformed chr) (pos + 1) decoder
           ) else (
             decoder.i_pos <- decoder.i_pos + pos ;
             decoder.byte_count <- decoder.byte_count + pos ;
+            decoder.limit_count <- decoder.limit_count + pos ;
             decoder.s <- state ;
             refill decode_base64 decoder )
       | #flush_and_malformed as v ->
@@ -220,7 +183,22 @@ and t_decode_base64 chr decoder =
           ret decode_base64 v pos decoder
     in
     go 1 (r_repr decoder.s chr)
-  else malformed (Bytes.make 1 chr) 0 0 1
+  else (
+    decoder.i_pos <- decoder.i_pos + 1 ;
+    ret decode_base64 (malformed chr) 1 decoder)
+
+and decode_base64_lf_after_cr consumed decoder = (* [consumed]: bytes read before the CR, added to the count given to [ret] *)
+  let rem = i_rem decoder in
+  if rem < 0 then
+    ret decode_base64 (malformed '\r') (consumed + 1) decoder
+  else if rem = 0 then refill (decode_base64_lf_after_cr consumed) decoder
+  else
+    match unsafe_byte decoder.i decoder.i_off decoder.i_pos with
+    | '\n' ->
+      decoder.i_pos <- decoder.i_pos + 1 ;
+      ret decode_base64 `Line_break (consumed + 2) decoder
+    | _ ->
+      ret decode_base64 (malformed '\r') (consumed + 1) decoder
 
 and decode_base64 decoder =
   let rem = i_rem decoder in
@@ -228,7 +206,7 @@ and decode_base64 decoder =
     if rem < 0 then
       ret
         (fun decoder ->
-          if padding decoder.s decoder.padding then `End else `Wrong_padding )
+          if padding decoder.s decoder.padding then `End else wrong_padding decoder )
         (t_flush decoder.s) 0 decoder
     else refill decode_base64 decoder
   else
@@ -242,8 +220,12 @@ and decode_base64 decoder =
     | ' ' | '\t' ->
         decoder.i_pos <- decoder.i_pos + 1 ;
         ret decode_base64 `Wsp 1 decoder
-    | '\r' -> t_need decoder 2 ; t_fill t_crlf decoder
-    | chr -> malformed (Bytes.make 1 chr) 0 0 1
+    | '\r' ->
+        decoder.i_pos <- decoder.i_pos + 1 ;
+        decode_base64_lf_after_cr 0 decoder
+    | chr ->
+        decoder.i_pos <- decoder.i_pos + 1 ;
+        ret decode_base64 (malformed chr) 1 decoder
 
 let pp_base64 decoder = function
   | `Line_break -> reset decoder ; decoder.k decoder
@@ -268,9 +250,6 @@ let decoder src =
   ; i_len
   ; i
   ; s= {quantum= 0; size= 0; buffer= Bytes.create 3}
-  ; h= Bytes.create 2
-  ; h_len= 0
-  ; h_need= 0
   ; padding= 0
   ; unsafe= false
   ; byte_count= 0
diff --git a/test/dune b/test/dune
index cf4384b..8a0c40b 100644
--- a/test/dune
+++ b/test/dune
@@ -1,6 +1,6 @@
 (executable
  (name test)
- (libraries base64 rresult alcotest bos))
+ (libraries base64 base64.rfc2045 rresult alcotest bos))
 
 (alias
  (name runtest)
diff --git a/test/test.ml b/test/test.ml
index 19bc36f..a6236a3 100644
--- a/test/test.ml
+++ b/test/test.ml
@@ -120,7 +120,97 @@ let test_cfcs () =
     Alcotest.(check string) (sprintf "decode %s" r) c (Base64.decode_exn ~pad:false ~off ~len r);
   ) cfcs_tests
 
-
+exception Malformed
+exception Wrong_padding
+
+let strict_base64_rfc2045_of_string x =
+  let decoder = Base64_rfc2045.decoder (`String x) in
+  let res = Buffer.create 16 in
+
+  let rec go () = match Base64_rfc2045.decode decoder with
+    | `End -> ()
+    | `Wrong_padding -> raise Wrong_padding
+    | `Malformed _ -> raise Malformed
+    | `Flush x -> Buffer.add_string res x ; go ()
+    | `Await -> Alcotest.failf "Retrieve impossible case: `Await" in
+
+  Base64_rfc2045.src decoder (Bytes.unsafe_of_string x) 0 (String.length x) ;
+  go () ; Buffer.contents res
+
+let relaxed_base64_rfc2045_of_string x =
+  let decoder = Base64_rfc2045.decoder (`String x) in
+  let res = Buffer.create 16 in
+
+  let rec go () = match Base64_rfc2045.decode decoder with
+    | `End -> ()
+    | `Wrong_padding -> go ()
+    | `Malformed _ -> go ()
+    | `Flush x -> Buffer.add_string res x ; go ()
+    | `Await -> Alcotest.failf "Retrieve impossible case: `Await" in
+
+  Base64_rfc2045.src decoder (Bytes.unsafe_of_string x) 0 (String.length x) ;
+  go () ; Buffer.contents res
+
+let test_strict_rfc2045 =
+  [ "c2FsdXQgbGVzIGNvcGFpbnMgZmF1dCBhYnNvbHVtZW50IHF1ZSBqZSBkw6lwYXNzZSBsZXMgODAg\r\n\
+     Y2hhcmFjdGVycyBwb3VyIHZvaXIgc2kgbW9uIGVuY29kZXIgZml0cyBiaWVuIGRhbnMgbGVzIGxp\r\n\
+     bWl0ZXMgZGUgbGEgUkZDIDIwNDUgLi4u",
+    "salut les copains faut absolument que je dépasse les 80 characters pour voir si \
+     mon encoder fits bien dans les limites de la RFC 2045 ..."
+  ; "", ""
+  ; "Zg==", "f"
+  ; "Zm8=", "fo"
+  ; "Zm9v", "foo"
+  ; "Zm9vYg==", "foob"
+  ; "Zm9vYmE=", "fooba"
+  ; "Zm9vYmFy", "foobar" ]
+
+let test_relaxed_rfc2045 =
+  [ "Zg", "f"
+  ; "Zm\n8", "fo"
+  ; "Zm\r9v", "foo"
+  ; "Zm9 vYg", "foob"
+  ; "Zm9\r\n vYmE", "fooba"
+  ; "Zm9évYmFy", "foobar" ]
+
+let strict_base64_rfc2045_to_string x =
+  let res = Buffer.create 16 in
+  let encoder = Base64_rfc2045.encoder (`Buffer res) in
+  String.iter
+    (fun chr -> match Base64_rfc2045.encode encoder (`Char chr) with
+       | `Ok -> ()
+       | `Partial -> Alcotest.failf "Retrieve impossible case for (`Char %02x): `Partial" (Char.code chr))
+    x ;
+  match Base64_rfc2045.encode encoder `End with
+  | `Ok -> Buffer.contents res
+  | `Partial -> Alcotest.fail "Retrieve impossible case for `End: `Partial"
+
+let test_strict_with_malformed_input_rfc2045 =
+  List.mapi (fun i (has, _) ->
+      Alcotest.test_case (Fmt.strf "strict rfc2045 (malformed) - %02d" i) `Quick @@ fun () ->
+      try
+        let _ = strict_base64_rfc2045_of_string has in
+        Alcotest.failf "Strict parser accepted malformed input: %S" has
+      with Malformed | Wrong_padding -> () )
+    test_relaxed_rfc2045
+
+let test_strict_rfc2045 =
+  List.mapi (fun i (has, expect) ->
+      Alcotest.test_case (Fmt.strf "strict rfc2045 - %02d" i) `Quick @@ fun () ->
+      try
+        let res0 = strict_base64_rfc2045_of_string has in
+        let res1 = strict_base64_rfc2045_to_string res0 in
+        Alcotest.(check string) "encode(decode(x)) = x" has res1 ;
+        Alcotest.(check string) "decode(x)" expect res0
+      with Malformed | Wrong_padding -> Alcotest.failf "Invalid input %S" has)
+    test_strict_rfc2045
+
+let test_relaxed_rfc2045 =
+  List.mapi (fun i (has, expect) ->
+      Alcotest.test_case (Fmt.strf "relaxed rfc2045 - %02d" i) `Quick @@ fun () ->
+      let res0 = relaxed_base64_rfc2045_of_string has in
+      Alcotest.(check string) "decode(x)" expect res0)
+    test_relaxed_rfc2045
 
 let test_invariants = [ "Alphabet size", `Quick, alphabet_size ]
 let test_codec = [ "RFC4648 test vectors", `Quick, test_rfc4648
@@ -133,5 +223,8 @@ let () =
   Alcotest.run "Base64" [
     "invariants", test_invariants;
     "codec", test_codec;
+    "rfc2045 (strict)", test_strict_rfc2045;
+    "rfc2045 (malformed)", test_strict_with_malformed_input_rfc2045;
+    "rfc2045 (relaxed)", test_relaxed_rfc2045;
   ]
 