Skip to content

Commit

Permalink
client_connection: add upgrade support (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
anmonteiro committed May 11, 2020
1 parent 9bf50db commit 036e507
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 19 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ Unreleased
interface, don't require a `report_exn` function, only a `state` function
that returns the socket state
([#30](https://github.com/anmonteiro/httpaf/pull/30))
- httpaf, httpaf-lwt, httpaf-async: Add support for upgrading connections on
the client. ([#31](https://github.com/anmonteiro/httpaf/pull/31))

httpaf (upstream) 0.6.5
--------------
Expand Down
2 changes: 2 additions & 0 deletions async/httpaf_async.ml
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ module Client = struct
|> ignore;
reader_thread ()
end
| `Yield ->
Client_connection.yield_writer conn reader_thread;
| `Close ->
(* Log.Global.printf "read_close(%d)%!" (Fd.to_int_exn fd); *)
Ivar.fill read_complete ();
Expand Down
38 changes: 34 additions & 4 deletions lib/client_connection.ml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type t =
(* invariant: If [request_queue] is not empty, then the head of the queue
has already written the request headers to the wire. *)
; wakeup_writer : (unit -> unit) list ref
; wakeup_reader : (unit -> unit) list ref
}

let is_closed t =
Expand All @@ -63,6 +64,11 @@ let is_active t =
let current_respd_exn t =
Queue.peek t.request_queue

let on_wakeup_reader t k =
if is_closed t
then failwith "on_wakeup_reader on closed conn"
else t.wakeup_reader := k::!(t.wakeup_reader)

let on_wakeup_writer t k =
if is_closed t
then failwith "on_wakeup_writer on closed conn"
Expand All @@ -73,13 +79,19 @@ let wakeup_writer t =
t.wakeup_writer := [];
List.iter (fun f -> f ()) fs

let wakeup_reader t =
let fs = !(t.wakeup_reader) in
t.wakeup_reader := [];
List.iter (fun f -> f ()) fs

let[@ocaml.warning "-16"] create ?(config=Config.default) =
let request_queue = Queue.create () in
{ config
; reader = Reader.response request_queue
; writer = Writer.create ()
; request_queue
; wakeup_writer = ref []
; wakeup_reader = ref []
}

let request t request ~error_handler ~response_handler =
Expand Down Expand Up @@ -123,6 +135,7 @@ let shutdown_reader t =
Reader.force_close t.reader;
if is_active t
then Respd.close_response_body (current_respd_exn t)
else wakeup_reader t

let shutdown_writer t =
flush_request_body t;
Expand All @@ -136,6 +149,7 @@ let shutdown_writer t =
let shutdown t =
shutdown_reader t;
shutdown_writer t;
wakeup_reader t;
wakeup_writer t;
;;

Expand Down Expand Up @@ -207,17 +221,29 @@ let advance_request_queue_if_necessary t =
end else if Reader.is_closed t.reader
then shutdown t

let next_read_operation t =
let _next_read_operation t =
advance_request_queue_if_necessary t;
match Reader.next t.reader with
if is_active t then begin
let respd = current_respd_exn t in
if Respd.requires_input respd then Reader.next t.reader
else if Respd.persistent_connection respd then `Yield
else begin
shutdown_reader t;
Reader.next t.reader
end
end else
Reader.next t.reader

let next_read_operation t =
match _next_read_operation t with
| `Error (`Parse(marks, message)) ->
let message = String.concat "" [ String.concat ">" marks; ": "; message] in
set_error_and_handle t (`Malformed_response message);
`Close
| `Error (`Invalid_response_body_length _ as error) ->
set_error_and_handle t error;
`Close
| (`Read | `Close) as operation -> operation
| (`Read | `Yield | `Close) as operation -> operation
;;

let read_with_more t bs ~off ~len more =
Expand All @@ -237,7 +263,7 @@ let read_eof t bs ~off ~len =
(* TODO: could just check for `Respd.requires_input`? *)
match respd.state with
| Uninitialized -> assert false
| Received_response _ | Closed -> ()
| Received_response _ | Closed | Upgraded _ -> ()
| Awaiting_response ->
(* TODO: review this. It makes sense to tear down the connection if an
* unexpected EOF is received. *)
Expand All @@ -253,6 +279,10 @@ let next_write_operation t =
Writer.next t.writer
;;

let yield_reader t k =
on_wakeup_reader t k
;;

let yield_writer t k =
if is_active t then begin
let respd = current_respd_exn t in
Expand Down
4 changes: 3 additions & 1 deletion lib/httpaf.mli
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ module Client_connection : sig
-> response_handler:response_handler
-> [`write] Body.t

val next_read_operation : t -> [ `Read | `Close ]
val next_read_operation : t -> [ `Read | `Yield | `Close ]
(** [next_read_operation t] returns a value describing the next operation
that the caller should conduct on behalf of the connection. *)

Expand Down Expand Up @@ -808,6 +808,8 @@ module Client_connection : sig
{- [`Closed] indicates that the output destination will no longer
accept bytes from the write processor. }} *)

val yield_reader : t -> (unit -> unit) -> unit

val yield_writer : t -> (unit -> unit) -> unit
(** [yield_writer t continue] registers with the connection to call
[continue] when writing should resume. {!yield_writer} should be called
Expand Down
35 changes: 26 additions & 9 deletions lib/respd.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ type state =
| Uninitialized
| Awaiting_response
| Received_response of Response.t * [`read] Body.t
| Upgraded of Response.t
| Closed

type t =
Expand All @@ -27,7 +28,13 @@ let create error_handler request request_body writer response_handler =
let t = Lazy.force t in
if t.persistent then
t.persistent <- Response.persistent_connection response;
t.state <- Received_response(response, body);
let next_state = match response.status with
| `Switching_protocols ->
Upgraded response
| _ ->
Received_response (response, body)
in
t.state <- next_state;
response_handler response body
and t =
lazy
Expand Down Expand Up @@ -58,14 +65,14 @@ let report_error t error =
(* t.persistent <- false; *)
(* TODO: drain queue? *)
match t.state, t.error_code with
| (Uninitialized | Awaiting_response | Received_response _), `Ok ->
| (Uninitialized | Awaiting_response | Received_response _ | Upgraded _), `Ok ->
t.state <- Closed;
t.error_code <- (error :> [`Ok | error]);
t.error_handler error
| Uninitialized, `Exn _ ->
(* TODO(anmonteiro): Not entirely sure this is possible in the client. *)
failwith "httpaf.Reqd.report_exn: NYI"
| (Uninitialized | Awaiting_response | Received_response _ | Closed), _ ->
| (Uninitialized | Awaiting_response | Received_response _ | Closed | Upgraded _), _ ->
(* XXX(seliopou): Once additional logging support is added, log the error
* in case it is not spurious. *)
()
Expand All @@ -80,36 +87,46 @@ let close_response_body t =
| Closed -> ()
| Received_response (_, response_body) ->
Body.close_reader response_body
| Upgraded _ -> t.state <- Closed

let requires_input t =
match t.state with
| Uninitialized -> true
| Awaiting_response -> true
| Upgraded _ -> false
| Received_response (_, response_body) ->
not (Body.is_closed response_body)
| Closed -> false

let requires_output { request_body; state; _ } =
state = Uninitialized ||
not (Body.is_closed request_body) ||
Body.has_pending_output request_body
match state with
| Upgraded _ ->
(* XXX(anmonteiro): Connections that have been upgraded "require output"
* forever, but outside the HTTP layer, meaning they're permanently
* "yielding". For now they need to be explicitly shutdown in order to
* transition the response descriptor to the `Closed` state. *)
true
| state ->
state = Uninitialized ||
not (Body.is_closed request_body) ||
Body.has_pending_output request_body

let is_complete t =
not (requires_input t || requires_output t)

let flush_request_body { request; request_body; writer; _ } =
if Body.has_pending_output request_body
then
if Body.has_pending_output request_body then begin
let encoding =
match Request.body_length request with
| `Fixed _ | `Chunked as encoding -> encoding
| `Error _ -> assert false (* XXX(seliopou): This needs to be handled properly *)
in
Body.transfer_to_writer_with_encoding request_body ~encoding writer
end

let flush_response_body t =
match t.state with
| Uninitialized | Awaiting_response | Closed -> ()
| Uninitialized | Awaiting_response | Closed | Upgraded _ -> ()
| Received_response(_, response_body) ->
try Body.execute_read response_body
(* TODO: report_exn *)
Expand Down
40 changes: 36 additions & 4 deletions lib_test/test_client_connection.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ let reader_ready t =
`Read (next_read_operation t :> [`Close | `Read | `Yield]);
;;

let reader_yielded t =
Alcotest.check read_operation "Reader is in a yield state"
`Yield (next_read_operation t :> [`Close | `Read | `Yield]);
;;

let write_string ?(msg="output written") t str =
let len = String.length str in
Alcotest.(check (option string)) msg
Expand Down Expand Up @@ -437,7 +442,7 @@ let test_partial_input () =
request
t
request'
~response_handler:response_handler
~response_handler
~error_handler:no_error_handler
in
write_request t request';
Expand Down Expand Up @@ -471,7 +476,7 @@ let test_empty_fixed_body () =
request
t
request'
~response_handler:response_handler
~response_handler
~error_handler:no_error_handler
in
write_request t request';
Expand Down Expand Up @@ -502,7 +507,7 @@ let test_fixed_body () =
request
t
request'
~response_handler:response_handler
~response_handler
~error_handler:no_error_handler
in
write_request t request';
Expand Down Expand Up @@ -536,7 +541,7 @@ let test_fixed_body_persistent_connection () =
request
t
request'
~response_handler:response_handler
~response_handler
~error_handler:no_error_handler
in
write_request t request';
Expand All @@ -548,6 +553,32 @@ let test_fixed_body_persistent_connection () =
writer_yielded t;
;;

let test_client_upgrade () =
let request' = Request.create
~headers:(Headers.of_list ["Content-Length", "0"])
`GET "/"
in
let t = create ?config:None in
let response = Response.create `Switching_protocols in
let body =
request
t
request'
~response_handler:(default_response_handler response)
~error_handler:no_error_handler
in
write_request t request';
writer_yielded t;
Body.close_writer body;
reader_ready t;
read_response t response;
reader_yielded t;
writer_yielded t;
shutdown t;
reader_closed t;
writer_closed t;
;;

let tests =
[ "GET" , `Quick, test_get
; "Response EOF", `Quick, test_response_eof
Expand All @@ -563,6 +594,7 @@ let tests =
; "Empty fixed body shuts down writer", `Quick, test_empty_fixed_body
; "Fixed body shuts down writer if connection is not persistent", `Quick, test_fixed_body
; "Fixed body doesn't shut down the writer if connection is persistent",`Quick, test_fixed_body_persistent_connection
; "Client support for upgrading a connection", `Quick, test_client_upgrade
]

(*
Expand Down
6 changes: 5 additions & 1 deletion lwt/httpaf_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ module Client (Io: IO) = struct
let read_buffer = Buffer.create config.read_buffer_size in
let read_loop_exited, notify_read_loop_exited = Lwt.wait () in

let read_loop () =
let rec read_loop () =
let rec read_loop_step () =
match Client_connection.next_read_operation connection with
| `Read ->
Expand All @@ -224,6 +224,10 @@ module Client (Io: IO) = struct
read_loop_step ()
end

| `Yield ->
Client_connection.yield_reader connection read_loop;
Lwt.return_unit

| `Close ->
Lwt.wakeup_later notify_read_loop_exited ();
Io.shutdown_receive socket;
Expand Down

0 comments on commit 036e507

Please sign in to comment.