Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better handling of IO errors #669

Merged
merged 3 commits into from
Jul 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions cohttp-lwt-unix/src/io.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
*
}}}*)

exception IO_error of exn

let () =
Printexc.register_printer (function
| IO_error e -> Some ("IO error: " ^ Printexc.to_string e)
| _ -> None
);
if Sys.os_type <> "Win32" then
Sys.(set_signal sigpipe Signal_ignore);

Expand All @@ -29,17 +35,25 @@ type conn = Conduit_lwt_unix.flow
let src = Logs.Src.create "cohttp.lwt.io" ~doc:"Cohttp Lwt IO module"
module Log = (val Logs.src_log src : Logs.LOG)

let with_closed_handler f ~if_closed =
let wrap_read f ~if_closed =
(* TODO Use [Lwt_io.is_closed] when available:
https://github.com/ocsigen/lwt/pull/635 *)
avsm marked this conversation as resolved.
Show resolved Hide resolved
Lwt.catch f
(function
| Lwt_io.Channel_closed _ -> Lwt.return if_closed
| exn -> raise exn
| Lwt_io.Channel_closed _ -> Lwt.return if_closed
| Unix.Unix_error _ as e -> Lwt.fail (IO_error e)
| exn -> raise exn
)

let wrap_write f =
Lwt.catch f
(function
| Unix.Unix_error _ as e -> Lwt.fail (IO_error e)
| exn -> raise exn
)

let read_line ic =
with_closed_handler ~if_closed:None
wrap_read ~if_closed:None
(fun () ->
Lwt_io.read_line_opt ic >>= function
| None ->
Expand All @@ -52,16 +66,29 @@ let read_line ic =

let read ic count =
let count = min count Sys.max_string_length in
with_closed_handler ~if_closed:""
wrap_read ~if_closed:""
(fun () ->
Lwt_io.read ~count ic >>= fun buf ->
Log.debug (fun f -> f "<<<[%d] %s" count buf);
Lwt.return buf
)

let write oc buf =
wrap_write @@ fun () ->
Log.debug (fun f -> f ">>> %s" (String.trim buf));
Lwt_io.write oc buf

let flush oc =
wrap_write @@ fun () ->
Lwt_io.flush oc

type error = exn

let catch f =
Lwt.try_bind f Lwt.return_ok
(function
| IO_error e -> Lwt.return_error e
| ex -> Lwt.fail ex
)

let pp_error = Fmt.exn
7 changes: 3 additions & 4 deletions cohttp-lwt-unix/src/io.mli
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
*
}}}*)

include Cohttp.S.IO
with type 'a t = 'a Lwt.t
and type ic = Lwt_io.input_channel
include Cohttp_lwt.S.IO
with type ic = Lwt_io.input_channel
and type oc = Lwt_io.output_channel
and type conn = Conduit_lwt_unix.flow

and type error = exn
6 changes: 1 addition & 5 deletions cohttp-lwt-unix/src/net.mli
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@

(** Basic satisfaction of {! Cohttp_lwt.Net } *)

module IO : Cohttp.S.IO
with type 'a t = 'a Lwt.t
and type ic = Lwt_io.input_channel
and type oc = Lwt_io.output_channel
and type conn = Conduit_lwt_unix.flow
module IO = Io

type ctx = {
ctx : Conduit_lwt_unix.ctx;
Expand Down
61 changes: 52 additions & 9 deletions cohttp-lwt-unix/test/test_sanity.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ open Cohttp_lwt_unix_test

module Body = Cohttp_lwt.Body

module IO = Cohttp_lwt_unix.IO
module Request = struct
include Cohttp.Request
include (Make(IO) : module type of Make(IO) with type t := t)
end

let message = "Hello sanity!"

let chunk_body = ["one"; ""; " "; "bar"; ""]
Expand All @@ -15,6 +21,8 @@ let leak_repeat = 1024
let () = Debug.activate_debug ()
let () = Logs.set_level (Some Info)

let cond = Lwt_condition.create ()

let server =
List.map const [ (* t *)
Server.respond_string ~status:`OK ~body:message ();
Expand Down Expand Up @@ -74,8 +82,29 @@ let server =
)
))
]
@ (* client_close *)
[
fun _ _ ->
let ready = Lwt_condition.wait cond in
let i = ref 0 in
let stream = Lwt_stream.from (fun () ->
ready >|= fun () ->
incr i;
if !i > 1000 then failwith "Connection should have failed by now!";
Some (String.make 4096 'X')
)
in
Lwt.return (`Response (Cohttp.Response.make ~status:`OK (), `Stream stream))
]
|> response_sequence

let check_logs test () =
let old = Logs.(warn_count () + err_count ()) in
test () >|= fun () ->
let new_errs = Logs.(warn_count () + err_count ()) - old in
if new_errs > 0 then
Fmt.failwith "Test produced %d log messages at level >= warn" new_errs

let ts =
Cohttp_lwt_unix_test.test_server_s server begin fun uri ->
let t () =
Expand Down Expand Up @@ -179,15 +208,29 @@ let ts =
Body.to_string body >|= fun body ->
assert_equal ~printer "expert 2" body
in
[ "sanity test", t
; "empty chunk test", empty_chunk
; "pipelined chunk test", pipelined_chunk
; "no body when response is not modified", not_modified_has_no_body
; "pipelined with interleaving requests", pipelined_interleave
; "massive chunked", massive_chunked
; "unreadable file returns 500", unreadable_file_500
; "no leaks on requests", test_no_leak
; "expert response", expert_pipelined
let client_close () =
Cohttp_lwt_unix.Net.(connect_uri ~ctx:default_ctx) uri >>= fun (_conn, ic, oc) ->
let req = Cohttp.Request.make_for_client ~chunked:false `GET (Uri.with_path uri "/test.html") in
Request.write (fun _writer -> Lwt.return_unit) req oc
>>= fun () ->
Response.read ic >>= function
| `Eof | `Invalid _ -> assert false
| `Ok rsp ->
assert_equal ~printer:Cohttp.Code.string_of_status `OK (Cohttp.Response.status rsp);
Cohttp_lwt_unix.Net.close ic oc;
Lwt_condition.broadcast cond ();
Lwt.pause ()
in
[ "sanity test", check_logs t
; "empty chunk test", check_logs empty_chunk
; "pipelined chunk test", check_logs pipelined_chunk
; "no body when response is not modified", check_logs not_modified_has_no_body
; "pipelined with interleaving requests", check_logs pipelined_interleave
; "massive chunked", check_logs massive_chunked
; "unreadable file returns 500", unreadable_file_500
; "no leaks on requests", check_logs test_no_leak
; "expert response", check_logs expert_pipelined
; "client_close", check_logs client_close
]
end

Expand Down
2 changes: 1 addition & 1 deletion cohttp-lwt/src/cohttp_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
*
}}}*)

module type IO = S.IO with type 'a t = 'a Lwt.t
module type IO = S.IO

module Request = Cohttp.Request
module Response = Cohttp.Response
Expand Down
12 changes: 11 additions & 1 deletion cohttp-lwt/src/s.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@
functors must be instantiated by an implementation that provides
a concrete IO monad. *)

module type IO = Cohttp.S.IO with type 'a t = 'a Lwt.t
module type IO = sig
include Cohttp.S.IO with type 'a t = 'a Lwt.t

type error

val catch : (unit -> 'a t) -> ('a, error) result t
(** [catch f] is [f () >|= Result.ok], unless [f] fails with an IO error,
in which case it returns the error. *)

val pp_error : Format.formatter -> error -> unit
end
(** The IO module is specialized for the [Lwt] monad. *)

(** The [Net] module type defines how to connect to a remote node
Expand Down
9 changes: 7 additions & 2 deletions cohttp-lwt/src/server.ml
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ module Make(IO:S.IO) = struct
(function
| Out_of_memory -> Lwt.fail Out_of_memory
| exn ->
Log.err (fun f -> f "Error handling %a: %s\n%!" Request.pp_hum req (Printexc.to_string exn));
Log.err (fun f -> f "Error handling %a: %s" Request.pp_hum req (Printexc.to_string exn));
respond_error ~body:"Internal Server Error" () >|= fun rsp ->
`Response rsp
))
Expand Down Expand Up @@ -147,7 +147,12 @@ module Make(IO:S.IO) = struct
let conn_closed () = spec.conn_closed (io_id,conn_id) in
Lwt.finalize
(fun () ->
handle_client ic oc (io_id,conn_id) spec.callback
IO.catch (fun () -> handle_client ic oc (io_id,conn_id) spec.callback)
>>= function
| Ok () -> Lwt.return_unit
| Error e ->
Log.info (fun m -> m "IO error while handling client: %a" IO.pp_error e);
Lwt.return_unit
)
(fun () ->
(* Clean up resources when the response stream terminates and call
Expand Down
35 changes: 29 additions & 6 deletions cohttp-mirage/src/io.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,38 +21,61 @@ open Lwt.Infix

module Make (Channel: Mirage_channel_lwt.S) = struct

type error =
| Read_error of Channel.error
| Write_error of Channel.write_error

let pp_error f = function
| Read_error e -> Channel.pp_error f e
| Write_error e -> Channel.pp_write_error f e

type 'a t = 'a Lwt.t
type ic = Channel.t
type oc = Channel.t
type conn = Channel.flow

let failf fmt = Fmt.kstrf Lwt.fail_with fmt
exception Read_exn of Channel.error
exception Write_exn of Channel.write_error

let () =
Printexc.register_printer (function
| Read_exn e -> Some (Format.asprintf "IO read error: %a" Channel.pp_error e)
| Write_exn e -> Some (Format.asprintf "IO write error: %a" Channel.pp_write_error e)
| _ -> None
)

let read_line ic =
Channel.read_line ic >>= function
| Ok (`Data []) -> Lwt.return_none
| Ok `Eof -> Lwt.return_none
| Ok (`Data bufs) -> Lwt.return (Some (Cstruct.copyv bufs))
| Error e -> failf "Flow error: %a" Channel.pp_error e
| Ok (`Data bufs) -> Lwt.return_some (Cstruct.copyv bufs)
| Error e -> Lwt.fail (Read_exn e)

let read ic len =
Channel.read_some ~len ic >>= function
| Ok (`Data buf) -> Lwt.return (Cstruct.to_string buf)
| Ok `Eof -> Lwt.return ""
| Error e -> failf "Flow error: %a" Channel.pp_error e
| Error e -> Lwt.fail (Read_exn e)

let write oc buf =
Channel.write_string oc buf 0 (String.length buf);
Channel.flush oc >>= function
| Ok () -> Lwt.return_unit
| Error `Closed -> Lwt.fail_with "Trying to write on closed channel"
| Error e -> failf "Flow error: %a" Channel.pp_write_error e
| Error e -> Lwt.fail (Write_exn e)

let flush _ =
(* NOOP since we flush in the normal writer functions above *)
Lwt.return_unit

let (>>= ) = Lwt.( >>= )
let (>>= ) = Lwt.( >>= )
let return = Lwt.return

let catch f =
Lwt.try_bind f Lwt.return_ok
(function
| Read_exn e -> Lwt.return_error (Read_error e)
| Write_exn e -> Lwt.return_error (Write_error e)
| ex -> Lwt.fail ex
)
end
5 changes: 2 additions & 3 deletions cohttp-mirage/src/io.mli
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

(** Cohttp IO implementation using Mirage channels. *)

module Make (Channel: Mirage_channel_lwt.S) : Cohttp.S.IO
with type 'a t = 'a Lwt.t
and type ic = Channel.t
module Make (Channel: Mirage_channel_lwt.S) : Cohttp_lwt.S.IO
with type ic = Channel.t
and type oc = Channel.t
and type conn = Channel.flow