diff --git a/lib/client_connection.ml b/lib/client_connection.ml index 13bf5993..54fa766a 100644 --- a/lib/client_connection.ml +++ b/lib/client_connection.ml @@ -130,7 +130,6 @@ let shutdown t = ;; let set_error_and_handle t error = - Reader.force_close t.reader; Queue.iter (fun respd -> match Respd.input_state respd with | Wait | Ready -> diff --git a/lib/respd.ml b/lib/respd.ml index e4d0107a..91df5226 100644 --- a/lib/respd.ml +++ b/lib/respd.ml @@ -68,7 +68,8 @@ let report_error t error = | Uninitialized, `Exn _ -> (* TODO(anmonteiro): Not entirely sure this is possible in the client. *) failwith "httpaf.Reqd.report_exn: NYI" - | Received_response _, `Ok -> + | Received_response (_, response_body), `Ok -> + Body.close_reader response_body; t.error_code <- (error :> [`Ok | error]); t.error_handler error | (Uninitialized | Awaiting_response | Received_response _ | Closed | Upgraded _), _ -> @@ -130,8 +131,6 @@ let flush_response_body t = match t.state with | Uninitialized | Awaiting_response | Closed | Upgraded _ -> () | Received_response(_, response_body) -> - try Body.execute_read response_body - (* TODO: report_exn *) - with exn -> - Format.eprintf "EXN %S@." (Printexc.to_string exn) - (* report_exn t exn *) + if Body.has_pending_output response_body + then try Body.execute_read response_body + with exn -> report_error t (`Exn exn) diff --git a/lib_test/test_client_connection.ml b/lib_test/test_client_connection.ml index b972e9c5..8c591f9e 100644 --- a/lib_test/test_client_connection.ml +++ b/lib_test/test_client_connection.ml @@ -923,6 +923,119 @@ let test_exception_closes_reader_persistent_connection () = connection_is_shutdown t; ;; +let test_exception_reading_response_body () = + let request' = Request.create `GET "/" in + let error_handler_called = ref false in + let response = + Response.create ~headers:(Headers.of_list [ "content-length", "10" ]) `OK + in + let t = create ?config:None in + let response_handler expected_response response body = + Alcotest.check (module Response) "expected response" expected_response response; + let on_read _ ~off:_ ~len:_ = failwith "something went wrong" in + let on_eof () = () in + Body.schedule_read body ~on_read ~on_eof; + in + let body = + request + t + request' + ~response_handler:(response_handler response) + ~error_handler:(fun _ -> error_handler_called := true;) + in + Body.close_writer body; + write_request t request'; + read_response t response; + read_string t "hello"; + Alcotest.(check bool) "Error handler called" true !error_handler_called; + writer_closed t; + reader_closed t; + connection_is_shutdown t; +;; + +let test_exception_reading_response_body_last_chunk () = + let writer_woken_up = ref false in + let request' = Request.create `GET "/" in + let error_handler_called = ref false in + let response = + Response.create ~headers:(Headers.of_list [ "content-length", "10" ]) `OK + in + let t = create ?config:None in + let response_handler expected_response response body = + Alcotest.check (module Response) "expected response" expected_response response; + let on_eof () = () in + let on_read _ ~off:_ ~len:_ = + Body.schedule_read + body + ~on_read:(fun _ ~off:_ ~len:_ -> + report_exn t (Failure "something went wrong")) + ~on_eof + in + Body.schedule_read body ~on_read ~on_eof; + in + let body = + request + t + request' + ~response_handler:(response_handler response) + ~error_handler:(fun _ -> error_handler_called := true;) + in + Body.close_writer body; + write_request t request'; + read_response t response; + writer_yielded t; + yield_writer t (fun () -> writer_woken_up := true); + read_string t "hello"; + reader_ready t; + read_string t "hello"; + Alcotest.(check bool) "Error handler called" true !error_handler_called; + Alcotest.(check bool) "Writer woken up" true !writer_woken_up; + writer_closed t; + reader_closed t; + connection_is_shutdown t; +;; + +let test_async_exception_reading_response_body () = + let reader_woken_up = ref false in + let writer_woken_up = ref false in + let continue_reading = ref (fun () -> ()) in + let request' = Request.create `GET "/" in + let error_handler_called = ref false in + let response = + Response.create ~headers:(Headers.of_list [ "content-length", "10" ]) `OK + in + let t = create ?config:None in + let response_handler expected_response response body = + Alcotest.check (module Response) "expected response" expected_response response; + let on_read _ ~off:_ ~len:_ = + continue_reading := (fun () -> report_exn t (Failure "something went wrong")) in + let on_eof () = () in + Body.schedule_read body ~on_read ~on_eof; + in + let body = + request + t + request' + ~response_handler:(response_handler response) + ~error_handler:(fun _ -> error_handler_called := true;) + in + Body.close_writer body; + write_request t request'; + read_response t response; + read_string t "hello"; + reader_yielded t; + yield_reader t (fun () -> reader_woken_up := true); + writer_yielded t; + yield_writer t (fun () -> writer_woken_up := true); + !continue_reading (); + Alcotest.(check bool) "Error handler called" true !error_handler_called; + Alcotest.(check bool) "Reader wakes up if scheduling read" true !reader_woken_up; + Alcotest.(check bool) "Writer woken up" true !writer_woken_up; + writer_closed t; + reader_closed t; + connection_is_shutdown t; +;; + let tests = [ "commit parse after every header line", `Quick, test_commit_parse_after_every_header ; "GET" , `Quick, test_get @@ -947,4 +1060,7 @@ let tests = ; "EOF after handler closed response body", `Quick, test_eof_handler_closed_response_body ; "Exception closes the reader (on a non-persistent connection)", `Quick, test_exception_closes_reader ; "Exception closes the reader (on a persistent connection)", `Quick, test_exception_closes_reader_persistent_connection + ; "Exception while reading the response body", `Quick, test_exception_reading_response_body + ; "Exception while reading the response body", `Quick, test_exception_reading_response_body_last_chunk + ; "Aynchronous exception while reading the response body", `Quick, test_async_exception_reading_response_body ] diff --git a/nix/sources.nix b/nix/sources.nix index b50c6179..94a4c83c 100644 --- a/nix/sources.nix +++ b/nix/sources.nix @@ -3,7 +3,7 @@ let overlays = builtins.fetchTarball - https://github.com/anmonteiro/nix-overlays/archive/f7e2af5.tar.gz; + https://github.com/anmonteiro/nix-overlays/archive/4b5b3d8.tar.gz; in