diff --git a/cohttp-lwt-unix/src/io.ml b/cohttp-lwt-unix/src/io.ml index 6813fb8229..b5579d6b24 100644 --- a/cohttp-lwt-unix/src/io.ml +++ b/cohttp-lwt-unix/src/io.ml @@ -14,7 +14,13 @@ * }}}*) +exception IO_error of exn + let () = + Printexc.register_printer (function + | IO_error e -> Some ("IO error: " ^ Printexc.to_string e) + | _ -> None + ); if Sys.os_type <> "Win32" then Sys.(set_signal sigpipe Signal_ignore); @@ -29,17 +35,25 @@ type conn = Conduit_lwt_unix.flow let src = Logs.Src.create "cohttp.lwt.io" ~doc:"Cohttp Lwt IO module" module Log = (val Logs.src_log src : Logs.LOG) -let with_closed_handler f ~if_closed = +let wrap_read f ~if_closed = (* TODO Use [Lwt_io.is_closed] when available: https://github.com/ocsigen/lwt/pull/635 *) Lwt.catch f (function - | Lwt_io.Channel_closed _ -> Lwt.return if_closed - | exn -> raise exn + | Lwt_io.Channel_closed _ -> Lwt.return if_closed + | Unix.Unix_error _ as e -> Lwt.fail (IO_error e) + | exn -> raise exn + ) + +let wrap_write f = + Lwt.catch f + (function + | Unix.Unix_error _ as e -> Lwt.fail (IO_error e) + | exn -> raise exn ) let read_line ic = - with_closed_handler ~if_closed:None + wrap_read ~if_closed:None (fun () -> Lwt_io.read_line_opt ic >>= function | None -> @@ -52,7 +66,7 @@ let read_line ic = let read ic count = let count = min count Sys.max_string_length in - with_closed_handler ~if_closed:"" + wrap_read ~if_closed:"" (fun () -> Lwt_io.read ~count ic >>= fun buf -> Log.debug (fun f -> f "<<<[%d] %s" count buf); @@ -60,8 +74,21 @@ let read ic count = ) let write oc buf = + wrap_write @@ fun () -> Log.debug (fun f -> f ">>> %s" (String.trim buf)); Lwt_io.write oc buf let flush oc = + wrap_write @@ fun () -> Lwt_io.flush oc + +type error = exn + +let catch f = + Lwt.try_bind f Lwt.return_ok + (function + | IO_error e -> Lwt.return_error e + | ex -> Lwt.fail ex + ) + +let pp_error = Fmt.exn diff --git a/cohttp-lwt-unix/src/io.mli b/cohttp-lwt-unix/src/io.mli index 397436491c..1e90ce1b11 100644 --- a/cohttp-lwt-unix/src/io.mli +++ b/cohttp-lwt-unix/src/io.mli @@ -14,9 +14,8 @@ * }}}*) -include Cohttp.S.IO - with type 'a t = 'a Lwt.t - and type ic = Lwt_io.input_channel +include Cohttp_lwt.S.IO + with type ic = Lwt_io.input_channel and type oc = Lwt_io.output_channel and type conn = Conduit_lwt_unix.flow - + and type error = exn diff --git a/cohttp-lwt-unix/src/net.mli b/cohttp-lwt-unix/src/net.mli index ba157cca30..ca5bc291d1 100644 --- a/cohttp-lwt-unix/src/net.mli +++ b/cohttp-lwt-unix/src/net.mli @@ -16,11 +16,7 @@ (** Basic satisfaction of {! Cohttp_lwt.Net } *) -module IO : Cohttp.S.IO - with type 'a t = 'a Lwt.t - and type ic = Lwt_io.input_channel - and type oc = Lwt_io.output_channel - and type conn = Conduit_lwt_unix.flow +module IO = Io type ctx = { ctx : Conduit_lwt_unix.ctx; diff --git a/cohttp-lwt-unix/test/test_sanity.ml b/cohttp-lwt-unix/test/test_sanity.ml index 20b4740e20..e270da6b34 100644 --- a/cohttp-lwt-unix/test/test_sanity.ml +++ b/cohttp-lwt-unix/test/test_sanity.ml @@ -6,6 +6,12 @@ open Cohttp_lwt_unix_test module Body = Cohttp_lwt.Body +module IO = Cohttp_lwt_unix.IO +module Request = struct + include Cohttp.Request + include (Make(IO) : module type of Make(IO) with type t := t) +end + let message = "Hello sanity!" let chunk_body = ["one"; ""; " "; "bar"; ""] @@ -15,6 +21,8 @@ let leak_repeat = 1024 let () = Debug.activate_debug () let () = Logs.set_level (Some Info) +let cond = Lwt_condition.create () + let server = List.map const [ (* t *) Server.respond_string ~status:`OK ~body:message (); @@ -74,8 +82,29 @@ let server = ) )) ] + @ (* client_close *) + [ + fun _ _ -> + let ready = Lwt_condition.wait cond in + let i = ref 0 in + let stream = Lwt_stream.from (fun () -> + ready >|= fun () -> + incr i; + if !i > 1000 then failwith "Connection should have failed by now!"; + Some (String.make 4096 'X') + ) + in + Lwt.return (`Response (Cohttp.Response.make ~status:`OK (), `Stream stream)) + ] |> response_sequence +let check_logs test () = + let old = Logs.(warn_count () + err_count ()) in + test () >|= fun () -> + let new_errs = Logs.(warn_count () + err_count ()) - old in + if new_errs > 0 then + Fmt.failwith "Test produced %d log messages at level >= warn" new_errs + let ts = Cohttp_lwt_unix_test.test_server_s server begin fun uri -> let t () = @@ -179,15 +208,29 @@ let ts = Body.to_string body >|= fun body -> assert_equal ~printer "expert 2" body in - [ "sanity test", t - ; "empty chunk test", empty_chunk - ; "pipelined chunk test", pipelined_chunk - ; "no body when response is not modified", not_modified_has_no_body - ; "pipelined with interleaving requests", pipelined_interleave - ; "massive chunked", massive_chunked - ; "unreadable file returns 500", unreadable_file_500 - ; "no leaks on requests", test_no_leak - ; "expert response", expert_pipelined + let client_close () = + Cohttp_lwt_unix.Net.(connect_uri ~ctx:default_ctx) uri >>= fun (_conn, ic, oc) -> + let req = Cohttp.Request.make_for_client ~chunked:false `GET (Uri.with_path uri "/test.html") in + Request.write (fun _writer -> Lwt.return_unit) req oc + >>= fun () -> + Response.read ic >>= function + | `Eof | `Invalid _ -> assert false + | `Ok rsp -> + assert_equal ~printer:Cohttp.Code.string_of_status `OK (Cohttp.Response.status rsp); + Cohttp_lwt_unix.Net.close ic oc; + Lwt_condition.broadcast cond (); + Lwt.pause () + in + [ "sanity test", check_logs t + ; "empty chunk test", check_logs empty_chunk + ; "pipelined chunk test", check_logs pipelined_chunk + ; "no body when response is not modified", check_logs not_modified_has_no_body + ; "pipelined with interleaving requests", check_logs pipelined_interleave + ; "massive chunked", check_logs massive_chunked + ; "unreadable file returns 500", unreadable_file_500 + ; "no leaks on requests", check_logs test_no_leak + ; "expert response", check_logs expert_pipelined + ; "client_close", check_logs client_close ] end diff --git a/cohttp-lwt/src/cohttp_lwt.ml b/cohttp-lwt/src/cohttp_lwt.ml index 91224b3655..fe7c8c2951 100644 --- a/cohttp-lwt/src/cohttp_lwt.ml +++ b/cohttp-lwt/src/cohttp_lwt.ml @@ -14,7 +14,7 @@ * }}}*) -module type IO = S.IO with type 'a t = 'a Lwt.t +module type IO = S.IO module Request = Cohttp.Request module Response = Cohttp.Response diff --git a/cohttp-lwt/src/s.ml b/cohttp-lwt/src/s.ml index ec7dcf433e..f9a3d1500d 100644 --- a/cohttp-lwt/src/s.ml +++ b/cohttp-lwt/src/s.ml @@ -4,7 +4,17 @@ functors must be instantiated by an implementation that provides a concrete IO monad. *) -module type IO = Cohttp.S.IO with type 'a t = 'a Lwt.t +module type IO = sig + include Cohttp.S.IO with type 'a t = 'a Lwt.t + + type error + + val catch : (unit -> 'a t) -> ('a, error) result t + (** [catch f] is [f () >|= Result.ok], unless [f] fails with an IO error, + in which case it returns the error. *) + + val pp_error : Format.formatter -> error -> unit +end (** The IO module is specialized for the [Lwt] monad. *) (** The [Net] module type defines how to connect to a remote node diff --git a/cohttp-lwt/src/server.ml b/cohttp-lwt/src/server.ml index 4738302f9f..b890daa99b 100644 --- a/cohttp-lwt/src/server.ml +++ b/cohttp-lwt/src/server.ml @@ -112,7 +112,7 @@ module Make(IO:S.IO) = struct (function | Out_of_memory -> Lwt.fail Out_of_memory | exn -> - Log.err (fun f -> f "Error handling %a: %s\n%!" Request.pp_hum req (Printexc.to_string exn)); + Log.err (fun f -> f "Error handling %a: %s" Request.pp_hum req (Printexc.to_string exn)); respond_error ~body:"Internal Server Error" () >|= fun rsp -> `Response rsp )) @@ -147,7 +147,12 @@ module Make(IO:S.IO) = struct let conn_closed () = spec.conn_closed (io_id,conn_id) in Lwt.finalize (fun () -> - handle_client ic oc (io_id,conn_id) spec.callback + IO.catch (fun () -> handle_client ic oc (io_id,conn_id) spec.callback) + >>= function + | Ok () -> Lwt.return_unit + | Error e -> + Log.info (fun m -> m "IO error while handling client: %a" IO.pp_error e); + Lwt.return_unit ) (fun () -> (* Clean up resources when the response stream terminates and call diff --git a/cohttp-mirage/src/io.ml b/cohttp-mirage/src/io.ml index 5becfac19a..3d7c33d03a 100644 --- a/cohttp-mirage/src/io.ml +++ b/cohttp-mirage/src/io.ml @@ -21,38 +21,61 @@ open Lwt.Infix module Make (Channel: Mirage_channel_lwt.S) = struct + type error = + | Read_error of Channel.error + | Write_error of Channel.write_error + + let pp_error f = function + | Read_error e -> Channel.pp_error f e + | Write_error e -> Channel.pp_write_error f e + type 'a t = 'a Lwt.t type ic = Channel.t type oc = Channel.t type conn = Channel.flow - let failf fmt = Fmt.kstrf Lwt.fail_with fmt + exception Read_exn of Channel.error + exception Write_exn of Channel.write_error + + let () = + Printexc.register_printer (function + | Read_exn e -> Some (Format.asprintf "IO read error: %a" Channel.pp_error e) + | Write_exn e -> Some (Format.asprintf "IO write error: %a" Channel.pp_write_error e) + | _ -> None + ) let read_line ic = Channel.read_line ic >>= function | Ok (`Data []) -> Lwt.return_none | Ok `Eof -> Lwt.return_none - | Ok (`Data bufs) -> Lwt.return (Some (Cstruct.copyv bufs)) - | Error e -> failf "Flow error: %a" Channel.pp_error e + | Ok (`Data bufs) -> Lwt.return_some (Cstruct.copyv bufs) + | Error e -> Lwt.fail (Read_exn e) let read ic len = Channel.read_some ~len ic >>= function | Ok (`Data buf) -> Lwt.return (Cstruct.to_string buf) | Ok `Eof -> Lwt.return "" - | Error e -> failf "Flow error: %a" Channel.pp_error e + | Error e -> Lwt.fail (Read_exn e) let write oc buf = Channel.write_string oc buf 0 (String.length buf); Channel.flush oc >>= function | Ok () -> Lwt.return_unit | Error `Closed -> Lwt.fail_with "Trying to write on closed channel" - | Error e -> failf "Flow error: %a" Channel.pp_write_error e + | Error e -> Lwt.fail (Write_exn e) let flush _ = (* NOOP since we flush in the normal writer functions above *) Lwt.return_unit - let (>>= ) = Lwt.( >>= ) + let (>>= ) = Lwt.( >>= ) let return = Lwt.return + let catch f = + Lwt.try_bind f Lwt.return_ok + (function + | Read_exn e -> Lwt.return_error (Read_error e) + | Write_exn e -> Lwt.return_error (Write_error e) + | ex -> Lwt.fail ex + ) end diff --git a/cohttp-mirage/src/io.mli b/cohttp-mirage/src/io.mli index 0a15e6f8e5..b7b6654218 100644 --- a/cohttp-mirage/src/io.mli +++ b/cohttp-mirage/src/io.mli @@ -19,8 +19,7 @@ (** Cohttp IO implementation using Mirage channels. *) -module Make (Channel: Mirage_channel_lwt.S) : Cohttp.S.IO - with type 'a t = 'a Lwt.t - and type ic = Channel.t +module Make (Channel: Mirage_channel_lwt.S) : Cohttp_lwt.S.IO + with type ic = Channel.t and type oc = Channel.t and type conn = Channel.flow