diff --git a/async/httpaf_async.ml b/async/httpaf_async.ml index accd3eec..cd2b2293 100644 --- a/async/httpaf_async.ml +++ b/async/httpaf_async.ml @@ -90,7 +90,8 @@ let read fd buffer = open Httpaf module Server = struct - let create_connection_handler ?(config=Config.default) ~request_handler ~error_handler = + let create_connection_handler + ?(config=Config.default) ~request_handler ~error_handler ~upgrade_handler = fun client_addr socket -> let fd = Socket.fd socket in let writev = Faraday_async.writev_of_fd fd in @@ -98,6 +99,19 @@ module Server = struct let error_handler = error_handler client_addr in let conn = Server_connection.create ~config ~error_handler request_handler in let read_complete = Ivar.create () in + let write_complete = Ivar.create () in + let upgrade_read, upgrade_write = Ivar.create (), Ivar.create () in + upon + (Deferred.both (Ivar.read upgrade_read) (Ivar.read upgrade_write)) + (fun ((), ()) -> + match upgrade_handler with + | None -> failwith "HTTP upgrades not supported" + | Some upgrade_handler -> + upgrade_handler client_addr socket + >>> fun () -> + if not (Fd.is_closed fd) then Socket.shutdown socket `Both; + Ivar.fill read_complete (); + Ivar.fill write_complete ()); let buffer = Buffer.create config.read_buffer_size in let rec reader_thread () = match Server_connection.next_read_operation conn with @@ -119,13 +133,13 @@ module Server = struct | `Yield -> (* Log.Global.printf "read_yield(%d)%!" (Fd.to_int_exn fd); *) Server_connection.yield_reader conn reader_thread + | `Upgrade -> Ivar.fill upgrade_read () | `Close -> (* Log.Global.printf "read_close(%d)%!" (Fd.to_int_exn fd); *) Ivar.fill read_complete (); if not (Fd.is_closed fd) then Socket.shutdown socket `Receive in - let write_complete = Ivar.create () in let rec writer_thread () = match Server_connection.next_write_operation conn with | `Write iovecs -> @@ -136,6 +150,7 @@ module Server = struct | `Yield -> (* Log.Global.printf "write_yield(%d)%!" (Fd.to_int_exn fd); *) Server_connection.yield_writer conn writer_thread; + | `Upgrade -> Ivar.fill upgrade_write () | `Close _ -> (* Log.Global.printf "write_close(%d)%!" (Fd.to_int_exn fd); *) Ivar.fill write_complete (); diff --git a/async/httpaf_async.mli b/async/httpaf_async.mli index cbb83f6b..9c3eb82b 100644 --- a/async/httpaf_async.mli +++ b/async/httpaf_async.mli @@ -8,6 +8,7 @@ module Server : sig : ?config : Config.t -> request_handler : ('a -> Server_connection.request_handler) -> error_handler : ('a -> Server_connection.error_handler) + -> upgrade_handler : ('a -> ([`Active], 'a) Socket.t -> unit Deferred.t) option -> ([< Socket.Address.t] as 'a) -> ([`Active], 'a) Socket.t -> unit Deferred.t diff --git a/examples/async/async_echo_post.ml b/examples/async/async_echo_post.ml index 0f524a2a..b0df8db5 100644 --- a/examples/async/async_echo_post.ml +++ b/examples/async/async_echo_post.ml @@ -10,7 +10,10 @@ let main port max_accepts_per_batch () = let where_to_listen = Tcp.Where_to_listen.of_port port in Tcp.(Server.create_sock ~on_handler_error:`Raise ~backlog:10_000 ~max_connections:10_000 ~max_accepts_per_batch where_to_listen) - (Server.create_connection_handler ~request_handler ~error_handler) + (Server.create_connection_handler + ~request_handler + ~error_handler + ~upgrade_handler:None) >>= fun _server -> Stdio.printf "Listening on port %i and echoing POST requests.\n" port; Stdio.printf "To send a POST request, try one of the following\n\n"; diff --git a/examples/async/async_echo_upgrade.ml b/examples/async/async_echo_upgrade.ml new file mode 100644 index 00000000..cb0c9c23 --- /dev/null +++ b/examples/async/async_echo_upgrade.ml @@ -0,0 +1,47 @@ +open Core +open Async + +open Httpaf_async + +let request_handler (_ : Socket.Address.Inet.t) = Httpaf_examples.Server.upgrade +let error_handler (_ : Socket.Address.Inet.t) = Httpaf_examples.Server.error_handler + +let upgrade_handler (_ : Socket.Address.Inet.t) socket = + let fd = Socket.fd socket in + let reader = Reader.create fd in + let writer = Writer.create fd in + Reader.read_one_chunk_at_a_time reader ~handle_chunk:(fun bigstring ~pos ~len -> + Writer.write_bigstring writer bigstring ~pos ~len; + return `Continue) + >>| function + | `Eof | `Stopped _ | `Eof_with_unconsumed_data _ -> () +;; + +let main port max_accepts_per_batch () = + let where_to_listen = Tcp.Where_to_listen.of_port port in + Tcp.(Server.create_sock ~on_handler_error:`Raise + ~backlog:10_000 ~max_connections:10_000 ~max_accepts_per_batch where_to_listen) + (Server.create_connection_handler + ~request_handler + ~error_handler + ~upgrade_handler:(Some upgrade_handler)) + >>= fun _server -> + Stdio.printf "Listening on port %i, upgrading, and echoing data.\n" port; + Stdio.printf "To send an interactive upgrade request, try\n\n"; + Stdio.printf " examples/script/upgrade-connect\n%!"; + Deferred.never () +;; + +let () = + Command.async + ~summary:"Echo POST requests" + Command.Param.( + map (both + (flag "-p" (optional_with_default 8080 int) + ~doc:"int Source port to listen on") + (flag "-a" (optional_with_default 1 int) + ~doc:"int Maximum accepts per batch")) + ~f:(fun (port, accepts) -> + (fun () -> main port accepts ()))) + |> Command.run +;; diff --git a/examples/async/dune b/examples/async/dune index 4008a21d..b107d5cf 100644 --- a/examples/async/dune +++ b/examples/async/dune @@ -1,6 +1,6 @@ (executables (libraries httpaf httpaf-async httpaf_examples async core) - (names async_echo_post async_get async_post)) + (names async_echo_post async_echo_upgrade async_get async_post)) (alias (name examples) diff --git a/examples/lib/httpaf_examples.ml b/examples/lib/httpaf_examples.ml index fd049718..0dca1c91 100644 --- a/examples/lib/httpaf_examples.ml +++ b/examples/lib/httpaf_examples.ml @@ -86,4 +86,14 @@ module Server = struct end; Body.Writer.close response_body ;; + + let upgrade reqd = + if Request.is_upgrade (Reqd.request reqd) then ( + let headers = Headers.of_list [ "connection", "upgrade" ] in + Reqd.respond_with_upgrade reqd headers; + ) else ( + let headers = Headers.of_list [ "connection", "close" ] in + Reqd.respond_with_string reqd (Response.create ~headers `Not_found) "" + ) + ;; end diff --git a/examples/lwt/dune b/examples/lwt/dune index fe8f8b0a..7760e3c2 100644 --- a/examples/lwt/dune +++ b/examples/lwt/dune @@ -1,6 +1,6 @@ (executables (libraries httpaf httpaf-lwt-unix httpaf_examples base stdio lwt lwt.unix) - (names lwt_get lwt_post lwt_echo_post)) + (names lwt_get lwt_post lwt_echo_post lwt_echo_upgrade)) (alias (name examples) diff --git a/examples/lwt/lwt_echo_post.ml b/examples/lwt/lwt_echo_post.ml index 18307107..42d95b9f 100644 --- a/examples/lwt/lwt_echo_post.ml +++ b/examples/lwt/lwt_echo_post.ml @@ -12,7 +12,10 @@ let main port = Lwt.async (fun () -> Lwt_io.establish_server_with_client_socket listen_address - (Server.create_connection_handler ~request_handler ~error_handler) + (Server.create_connection_handler + ~request_handler + ~error_handler + ~upgrade_handler:None) >|= fun _server -> Stdio.printf "Listening on port %i and echoing POST requests.\n" port; Stdio.printf "To send a POST request, try one of the following\n\n"; diff --git a/examples/lwt/lwt_echo_upgrade.ml b/examples/lwt/lwt_echo_upgrade.ml new file mode 100644 index 00000000..e0acd5a0 --- /dev/null +++ b/examples/lwt/lwt_echo_upgrade.ml @@ -0,0 +1,47 @@ +open Base +open Lwt.Infix +module Arg = Caml.Arg + +open Httpaf_lwt_unix + +let request_handler (_ : Unix.sockaddr) = Httpaf_examples.Server.upgrade +let error_handler (_ : Unix.sockaddr) = Httpaf_examples.Server.error_handler + +let upgrade_handler (_ : Unix.sockaddr) (fd : Lwt_unix.file_descr) = + let input = Lwt_io.of_fd fd ~mode:Input in + let output = Lwt_io.of_fd fd ~mode:Output in + let rec loop () = + Lwt_io.read input ~count:4096 + >>= function + | "" -> Lwt.return_unit + | data -> Lwt_io.write output data >>= loop + in + loop () +;; + +let main port = + let listen_address = Unix.(ADDR_INET (inet_addr_loopback, port)) in + Lwt.async (fun () -> + Lwt_io.establish_server_with_client_socket + listen_address + (Server.create_connection_handler + ~request_handler + ~error_handler + ~upgrade_handler:(Some upgrade_handler)) + >|= fun _server -> + Stdio.printf "Listening on port %i, upgrading, and echoing data.\n" port; + Stdio.printf "To send an interactive upgrade request, try\n\n"; + Stdio.printf " examples/script/upgrade-connect\n%!"); + let forever, _ = Lwt.wait () in + Lwt_main.run forever +;; + +let () = + let port = ref 8080 in + Arg.parse + ["-p", Arg.Set_int port, " Listening port number (8080 by default)"] + ignore + "Echoes POST requests. Runs forever."; + main !port +;; + diff --git a/examples/script/upgrade-connect b/examples/script/upgrade-connect new file mode 100755 index 00000000..416d30b7 --- /dev/null +++ b/examples/script/upgrade-connect @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +set -euo pipefail + +function headers { + printf "\ +GET / HTTP/1.1\r +Host: localhost\r +Connection: upgrade\r +\r +" +} + +( headers; echo hello; cat; echo bye ) | nc localhost 8080 --close diff --git a/lib/httpaf.mli b/lib/httpaf.mli index efb05ade..aea08a1f 100644 --- a/lib/httpaf.mli +++ b/lib/httpaf.mli @@ -564,6 +564,10 @@ module Request : sig more details. *) val pp_hum : Format.formatter -> t -> unit [@@ocaml.toplevel_printer] + + val is_upgrade : t -> bool + (** [is_upgrade t] returns true if the request has the "Connection: upgrade" + header. *) end @@ -659,6 +663,16 @@ module Reqd : sig val respond_with_bigstring : t -> Response.t -> Bigstringaf.t -> unit val respond_with_streaming : ?flush_headers_immediately:bool -> t -> Response.t -> Body.Writer.t + val respond_with_upgrade : ?reason:string -> t -> Headers.t -> unit + (** Initiate an HTTP upgrade. [Server_connection.next_write_request] and + [next_read_request] will begin returning [`Upgrade] once the response + headers have been written, which indicates that the runtime should take + over direct control of the socket rather than shuttling bytes through + httpaf. + + The headers must indicate a valid upgrade message, e.g. must include + "Connection: upgrade". See [Request.is_upgrade]. *) + (** {3 Exception Handling} *) val report_exn : t -> exn -> unit @@ -700,7 +714,7 @@ module Server_connection : sig (** [create ?config ?error_handler ~request_handler] creates a connection handler that will service individual requests with [request_handler]. *) - val next_read_operation : t -> [ `Read | `Yield | `Close ] + val next_read_operation : t -> [ `Read | `Yield | `Close | `Upgrade ] (** [next_read_operation t] returns a value describing the next operation that the caller should conduct on behalf of the connection. *) @@ -727,6 +741,7 @@ module Server_connection : sig val next_write_operation : t -> [ | `Write of Bigstringaf.t IOVec.t list | `Yield + | `Upgrade | `Close of int ] (** [next_write_operation t] returns a value describing the next operation that the caller should conduct on behalf of the connection. *) diff --git a/lib/parse.ml b/lib/parse.ml index 1fedbc23..32533e1c 100644 --- a/lib/parse.ml +++ b/lib/parse.ml @@ -245,6 +245,8 @@ module Reader = struct | `Fixed 0L -> handler request Body.Reader.empty; ok + | `Fixed _ | `Chunked when Request.is_upgrade request -> + return (Error (`Bad_request request)) | `Fixed _ | `Chunked as encoding -> let request_body = Body.Reader.create Bigstringaf.empty in handler request request_body; diff --git a/lib/reqd.ml b/lib/reqd.ml index 5569bf1a..2d09d3c1 100644 --- a/lib/reqd.ml +++ b/lib/reqd.ml @@ -37,14 +37,17 @@ type error = module Response_state = struct type t = | Waiting + | Upgrade of Response.t | Fixed of Response.t | Streaming of Response.t * Body.Writer.t end module Input_state = struct type t = + | Waiting | Ready | Complete + | Upgraded end module Output_state = struct @@ -52,6 +55,7 @@ module Output_state = struct | Waiting | Ready | Complete + | Upgraded end type error_handler = @@ -111,12 +115,14 @@ let response { response_state; _ } = match response_state with | Waiting -> None | Streaming (response, _) + | Upgrade response | Fixed response -> Some response let response_exn { response_state; _ } = match response_state with | Waiting -> failwith "httpaf.Reqd.response_exn: response has not started" | Streaming (response, _) + | Upgrade response | Fixed response -> response let respond_with_string t response str = @@ -133,6 +139,7 @@ let respond_with_string t response str = Writer.wakeup t.writer; | Streaming _ -> failwith "httpaf.Reqd.respond_with_string: response already started" + | Upgrade _ | Fixed _ -> failwith "httpaf.Reqd.respond_with_string: response already complete" @@ -150,6 +157,7 @@ let respond_with_bigstring t response (bstr:Bigstringaf.t) = Writer.wakeup t.writer; | Streaming _ -> failwith "httpaf.Reqd.respond_with_bigstring: response already started" + | Upgrade _ | Fixed _ -> failwith "httpaf.Reqd.respond_with_bigstring: response already complete" @@ -175,6 +183,7 @@ let unsafe_respond_with_streaming ~flush_headers_immediately t response = response_body | Streaming _ -> failwith "httpaf.Reqd.respond_with_streaming: response already started" + | Upgrade _ | Fixed _ -> failwith "httpaf.Reqd.respond_with_streaming: response already complete" @@ -183,6 +192,25 @@ let respond_with_streaming ?(flush_headers_immediately=false) t response = failwith "httpaf.Reqd.respond_with_streaming: invalid state, currently handling error"; unsafe_respond_with_streaming ~flush_headers_immediately t response +let respond_with_upgrade ?reason t headers = + match t.response_state with + | Waiting -> + if not (Request.is_upgrade t.request) then + failwith "httpaf.Reqd.respond_with_upgrade: request was not an upgrade request" + else ( + let response = Response.create ?reason ~headers `Switching_protocols in + t.response_state <- Upgrade response; + (* The parser ensures it only passes empty bodies in the case of an + upgrade request *) + assert (Body.Reader.is_closed t.request_body); + Writer.write_response t.writer response; + Writer.wakeup t.writer); + | Streaming _ -> + failwith "httpaf.Reqd.respond_with_upgrade: response already started" + | Upgrade _ + | Fixed _ -> + failwith "httpaf.Reqd.respond_with_upgrade: response already complete" + let report_error t error = t.persistent <- false; Body.Reader.close t.request_body; @@ -207,7 +235,7 @@ let report_error t error = | Streaming (_response, response_body), `Exn _ -> Body.Writer.close response_body; Writer.close_and_drain t.writer - | (Fixed _ | Streaming _ | Waiting) , _ -> + | (Fixed _ | Streaming _ | Waiting | Upgrade _) , _ -> (* XXX(seliopou): Once additional logging support is added, log the error * in case it is not spurious. *) () @@ -232,13 +260,18 @@ let persistent_connection t = t.persistent let input_state t : Input_state.t = - if Body.Reader.is_closed t.request_body - then Complete - else Ready + match t.response_state with + | Upgrade _ -> Upgraded + | Waiting when Request.is_upgrade t.request -> Waiting + | Waiting | Fixed _ | Streaming _ -> + if Body.Reader.is_closed t.request_body + then Complete + else Ready ;; let output_state t : Output_state.t = match t.response_state with + | Upgrade _ -> Upgraded | Fixed _ -> Complete | Streaming (_, response_body) -> if Body.Writer.has_pending_output response_body diff --git a/lib/request.ml b/lib/request.ml index 158fd4c2..204a2000 100644 --- a/lib/request.ml +++ b/lib/request.ml @@ -80,3 +80,8 @@ let persistent_connection ?proxy { version; headers; _ } = let pp_hum fmt { meth; target; version; headers } = Format.fprintf fmt "((method \"%a\") (target %S) (version \"%a\") (headers %a))" Method.pp_hum meth target Version.pp_hum version Headers.pp_hum headers + +let is_upgrade t = + match Headers.get t.headers "Connection" with + | None -> false + | Some header_val -> Headers.ci_equal header_val "upgrade" diff --git a/lib/serialize.ml b/lib/serialize.ml index 61c22131..76039b9d 100644 --- a/lib/serialize.ml +++ b/lib/serialize.ml @@ -195,4 +195,6 @@ module Writer = struct | `Close -> `Close (drained_bytes t) | `Yield -> `Yield | `Writev iovecs -> `Write iovecs + + let has_pending_output t = Faraday.has_pending_output t.encoder end diff --git a/lib/server_connection.ml b/lib/server_connection.ml index 839a10b3..1689e436 100644 --- a/lib/server_connection.ml +++ b/lib/server_connection.ml @@ -225,8 +225,10 @@ let rec _next_read_operation t = ) else ( let reqd = current_reqd_exn t in match Reqd.input_state reqd with + | Waiting -> _yield_reader t | Ready -> Reader.next t.reader | Complete -> _final_read_operation_for t reqd + | Upgraded -> `Upgrade ) and _final_read_operation_for t reqd = @@ -235,30 +237,37 @@ and _final_read_operation_for t reqd = Reader.next t.reader; ) else ( match Reqd.output_state reqd with - | Waiting | Ready -> - (* XXX(dpatti): This is a way in which the reader and writer are not - parallel -- we tell the writer when it needs to yield but the reader is - always asking for more data. This is the only branch in either - operation function that does not return `(Reader|Writer).next`, which - means there are surprising states you can get into. For example, we ask - the runtime to yield but then raise when it tries to because the reader - is closed. I don't think checking `is_closed` here makes sense - semantically, but I don't think checking it in `_next_read_operation` - makes sense either. I chose here so I could describe why. *) - if Reader.is_closed t.reader - then Reader.next t.reader - else `Yield + | Waiting | Ready -> _yield_reader t + | Upgraded -> + (* If the input state is not [Upgraded], the output state cannot be + either. *) + assert false | Complete -> advance_request_queue t; _next_read_operation t; ) + +and _yield_reader t = + (* XXX(dpatti): This is a way in which the reader and writer are not + parallel -- we tell the writer when it needs to yield but the reader is + always asking for more data. This is the only branch in either + operation function that does not return `(Reader|Writer).next`, which + means there are surprising states you can get into. For example, we ask + the runtime to yield but then raise when it tries to because the reader + is closed. I think this can be avoided if we allow this module to tell the + reader when it should yield/resume, then we'd just do an inlined + `Reader.next` call instead. I put this function here to describe why this + is subtle. *) + if Reader.is_closed t.reader + then Reader.next t.reader + else `Yield ;; let next_read_operation t = match _next_read_operation t with | `Error (`Parse _) -> set_error_and_handle t `Bad_request; `Close | `Error (`Bad_request request) -> set_error_and_handle ~request t `Bad_request; `Close - | (`Read | `Yield | `Close) as operation -> operation + | (`Read | `Yield | `Close | `Upgrade) as operation -> operation let rec read_with_more t bs ~off ~len more = let call_handler = Queue.is_empty t.request_queue in @@ -297,7 +306,13 @@ let rec _next_write_operation t = Reqd.flush_response_body reqd; Writer.next t.writer | Complete -> _final_write_operation_for t reqd - ) + | Upgraded -> + wakeup_reader t; + (* Even in the Upgrade case, we're still responsible for writing the + response header, so we might have work to do. *) + if Writer.has_pending_output t.writer + then Writer.next t.writer + else `Upgrade) and _final_write_operation_for t reqd = let next = @@ -306,7 +321,9 @@ and _final_write_operation_for t reqd = Writer.next t.writer; ) else ( match Reqd.input_state reqd with + | Waiting -> `Yield | Ready -> Writer.next t.writer; + | Upgraded -> `Upgrade | Complete -> advance_request_queue t; _next_write_operation t; diff --git a/lib_test/helpers.ml b/lib_test/helpers.ml index 9ce0be45..eb4840af 100644 --- a/lib_test/helpers.ml +++ b/lib_test/helpers.ml @@ -18,7 +18,7 @@ let response_to_string ?body r = Faraday.serialize_to_string f module Read_operation = struct - type t = [ `Read | `Yield | `Close ] + type t = [ `Read | `Yield | `Close | `Upgrade ] let pp_hum fmt (t : t) = let str = @@ -26,13 +26,14 @@ module Read_operation = struct | `Read -> "Read" | `Yield -> "Yield" | `Close -> "Close" + | `Upgrade -> "Upgrade" in Format.pp_print_string fmt str ;; end module Write_operation = struct - type t = [ `Write of Bigstringaf.t IOVec.t list | `Yield | `Close of int ] + type t = [ `Write of Bigstringaf.t IOVec.t list | `Yield | `Close of int | `Upgrade ] let iovecs_to_string iovecs = let len = IOVec.lengthv iovecs in @@ -50,12 +51,13 @@ module Write_operation = struct | `Write iovecs -> Format.fprintf fmt "Write %S" (iovecs_to_string iovecs) | `Yield -> Format.pp_print_string fmt "Yield" | `Close len -> Format.fprintf fmt "Close %i" len + | `Upgrade -> Format.pp_print_string fmt "Upgrade" ;; let to_write_as_string t = match t with | `Write iovecs -> Some (iovecs_to_string iovecs) - | `Close _ | `Yield -> None + | `Close _ | `Yield | `Upgrade -> None ;; end @@ -70,4 +72,5 @@ module Headers = struct let connection_close = Headers.of_list ["connection", "close"] let encoding_chunked = Headers.of_list ["transfer-encoding", "chunked"] let encoding_fixed n = Headers.of_list ["content-length", string_of_int n] + let upgrade protocol = Headers.of_list ["connection", "upgrade" ; "upgrade", protocol] end diff --git a/lib_test/test_client_connection.ml b/lib_test/test_client_connection.ml index 40542bdc..471e5ebe 100644 --- a/lib_test/test_client_connection.ml +++ b/lib_test/test_client_connection.ml @@ -41,12 +41,12 @@ let read_response t r = let reader_ready t = Alcotest.check read_operation "Reader is ready" - `Read (next_read_operation t :> [`Close | `Read | `Yield]); + `Read (next_read_operation t :> Read_operation.t); ;; let reader_closed t = Alcotest.check read_operation "Reader is closed" - `Close (next_read_operation t :> [`Close | `Read | `Yield]); + `Close (next_read_operation t :> Read_operation.t); ;; let write_string ?(msg="output written") t str = @@ -64,17 +64,17 @@ let write_request ?(msg="request written") t r = let writer_yielded t = Alcotest.check write_operation "Writer is in a yield state" - `Yield (next_write_operation t); + `Yield (next_write_operation t :> Write_operation.t); ;; let writer_closed t = Alcotest.check write_operation "Writer is closed" - (`Close 0) (next_write_operation t); + (`Close 0) (next_write_operation t :> Write_operation.t); ;; let connection_is_shutdown t = Alcotest.check read_operation "Reader is closed" - `Close (next_read_operation t :> [`Close | `Read | `Yield]); + `Close (next_read_operation t :> Read_operation.t); writer_closed t; ;; diff --git a/lib_test/test_server_connection.ml b/lib_test/test_server_connection.ml index 986399ab..1f012171 100644 --- a/lib_test/test_server_connection.ml +++ b/lib_test/test_server_connection.ml @@ -72,6 +72,9 @@ end = struct | `Close -> trace "reader: Close"; t.read_operation <- `Close + | `Upgrade -> + trace "reader: Upgrade"; + t.read_operation <- `Upgrade ;; let rec write_step t = @@ -91,6 +94,9 @@ end = struct | `Close n -> trace "writer: Close"; t.write_operation <- `Close n + | `Upgrade -> + trace "writer: Upgrade"; + t.write_operation <- `Upgrade ;; let create ?config ?error_handler request_handler = @@ -123,13 +129,13 @@ end = struct let current_read_operation t = match t.read_operation with | `Initial -> assert false - | `Read | `Yield | `Close as op -> op + | `Read | `Yield | `Close | `Upgrade as op -> op ;; let current_write_operation t = match t.write_operation with | `Initial -> assert false - | `Write _ | `Yield | `Close _ as op -> op + | `Write _ | `Yield | `Close _ | `Upgrade as op -> op ;; let do_read t f = @@ -140,9 +146,9 @@ end = struct trace "read: finished"; t.read_loop (); res - | `Yield | `Close as op -> - Alcotest.failf "Read attempted during operation: %a" - Read_operation.pp_hum op + | `Yield | `Close | `Upgrade as op -> + Alcotest.failf "Read attempted during operation: %a" + Read_operation.pp_hum op ;; let do_write t f = @@ -153,7 +159,7 @@ end = struct trace "write: finished"; t.write_loop (); res - | `Yield | `Close _ as op -> + | `Yield | `Close _ | `Upgrade as op -> Alcotest.failf "Write attempted during operation: %a" Write_operation.pp_hum op ;; @@ -219,6 +225,11 @@ let reader_closed t = `Close (current_read_operation t); ;; +let reader_upgraded t = + Alcotest.check read_operation "Reader is upgraded" + `Upgrade (current_read_operation t); +;; + (* Checks that the [len] prefixes of expected and the write match, and returns the rest. *) let write_partial_string ?(msg="output written") t expected len = @@ -271,6 +282,11 @@ let writer_closed ?(unread = 0) t = (`Close unread) (current_write_operation t); ;; +let writer_upgraded t = + Alcotest.check write_operation "Writer is upgraded" + `Upgrade (current_write_operation t); +;; + let connection_is_shutdown t = reader_closed t; writer_closed t; @@ -321,6 +337,16 @@ let streaming_handler ?(flush=false) response writes reqd = write (); ;; +let capture_handler () = + let fail _ = failwith "Captured handler was not invoked" in + let capture = ref fail in + let respond reqd f = + capture := fail; + f reqd + in + capture, (fun reqd -> capture := respond reqd) +;; + let synchronous_raise reqd = Reqd.report_exn reqd (Failure "caught this exception") ;; @@ -1096,6 +1122,90 @@ let test_flush_response_before_shutdown () = connection_is_shutdown t); ;; +let test_upgrade () = + let headers = Headers.upgrade "foo" in + let request_handler reqd = Reqd.respond_with_upgrade reqd headers in + let t = create request_handler in + read_request t (Request.create `GET "/" ~headers); + reader_upgraded t; + write_response t (Response.create `Switching_protocols ~headers); + writer_upgraded t; +;; + +let test_upgrade_where_server_does_not_upgrade () = + let respond, handler = capture_handler () in + let t = create handler in + read_request t (Request.create `GET "/" ~headers:(Headers.upgrade "foo")); + (* At this point, we don't know if the response handler will call respond_with_upgrade + or not. So we pause the reader until that is determined. *) + reader_yielded t; + + (* Now pretend the user doesn't want to do the upgrade and make sure we close the + connection *) + !respond (fun reqd -> + let response = Response.create `Bad_request ~headers:(Headers.encoding_fixed 0) in + Reqd.respond_with_string reqd response ""; + write_response t response); + + (* The connection is left healthy and can be used for more requests *) + read_request t (Request.create `GET "/" ~headers:(Headers.encoding_fixed 0)); + !respond (fun reqd -> + let response = Response.create `OK ~headers:(Headers.encoding_fixed 0) in + Reqd.respond_with_string reqd response ""; + write_response t response); +;; + +let test_upgrade_with_initial_data () = + let headers = Headers.upgrade "foo" in + let request_handler reqd = Reqd.respond_with_upgrade reqd headers in + let t = create request_handler in + let payload = request_to_string (Request.create `GET "/" ~headers) ^ "foo" in + let c = feed_string t payload in + Alcotest.(check int) "read consumes headers" 53 c; + reader_upgraded t; + write_response t (Response.create `Switching_protocols ~headers); + writer_upgraded t; +;; + +let test_upgrade_with_bad_body_length () = + let headers = Headers.upgrade "foo" in + let request_handler reqd = Reqd.respond_with_upgrade reqd headers in + let t = create request_handler in + read_request t + (Request.create `GET "/" ~headers:Headers.(headers @ encoding_fixed 100)); + reader_closed t; + write_response t (Response.create `Bad_request) ~body:"400"; + writer_closed t; +;; + +let test_asynchronous_upgrade () = + let headers = Headers.upgrade "foo" in + let respond, handler = capture_handler () in + let t = create handler in + read_request t (Request.create `GET "/" ~headers); + reader_yielded t; + + !respond (fun reqd -> Reqd.respond_with_upgrade reqd headers); + reader_upgraded t; + write_response t (Response.create `Switching_protocols ~headers); + writer_upgraded t; +;; + +let test_upgrade_interrupted_by_shutdown () = + let headers = Headers.upgrade "foo" in + let respond, handler = capture_handler () in + let t = create handler in + read_request t (Request.create `GET "/" ~headers); + reader_yielded t; + + shutdown t; + (* XXX(dpatti): If we call this, we try to write to the closed writer *) + (* !respond (fun reqd -> Reqd.respond_with_upgrade reqd headers); *) + ignore respond; + reader_closed t; + writer_closed t; +;; + let test_schedule_read_with_data_available () = let response = Response.create `OK in let body = ref None in @@ -1171,5 +1281,11 @@ let tests = ; "shutdown in request handler", `Quick, test_shutdown_in_request_handler ; "shutdown during asynchronous request", `Quick, test_shutdown_during_asynchronous_request ; "flush response before shutdown", `Quick, test_flush_response_before_shutdown + ; "upgrade", `Quick, test_upgrade + ; "upgrade where server does not upgrade", `Quick, test_upgrade_where_server_does_not_upgrade + ; "upgrade with initial data", `Quick, test_upgrade_with_initial_data + ; "upgrade with bad body length", `Quick, test_upgrade_with_bad_body_length + ; "asynchronous upgrade", `Quick, test_asynchronous_upgrade + ; "upgrade interrupted by shutdown", `Quick, test_upgrade_interrupted_by_shutdown ; "schedule read with data available", `Quick, test_schedule_read_with_data_available ] diff --git a/lwt-unix/httpaf_lwt_unix.ml b/lwt-unix/httpaf_lwt_unix.ml index 545fbc6a..eec42234 100644 --- a/lwt-unix/httpaf_lwt_unix.ml +++ b/lwt-unix/httpaf_lwt_unix.ml @@ -106,7 +106,8 @@ let shutdown socket command = module Config = Httpaf.Config module Server = struct - let create_connection_handler ?(config=Config.default) ~request_handler ~error_handler = + let create_connection_handler + ?(config=Config.default) ~request_handler ~upgrade_handler ~error_handler = fun client_addr socket -> let module Server_connection = Httpaf.Server_connection in let connection = @@ -118,6 +119,27 @@ module Server = struct let read_buffer = Buffer.create config.read_buffer_size in let read_loop_exited, notify_read_loop_exited = Lwt.wait () in + let write_loop_exited, notify_write_loop_exited = Lwt.wait () in + + let upgrade_read, notify_upgrade_read = Lwt.wait () in + let upgrade_write, notify_upgrade_write = Lwt.wait () in + Lwt.async (fun () -> + upgrade_read + >>= fun () -> + upgrade_write + >>= fun () -> + match upgrade_handler with + | None -> Lwt.fail_with "HTTP upgrades not supported" + | Some upgrade_handler -> + upgrade_handler client_addr socket + >>= fun () -> + if (Lwt_unix.state socket = Lwt_unix.Closed) + then Lwt.return_unit + else Lwt_unix.close socket + >>= fun () -> + Lwt.wakeup_later notify_read_loop_exited (); + Lwt.wakeup_later notify_write_loop_exited (); + Lwt.return_unit); let rec read_loop () = let rec read_loop_step () = @@ -140,6 +162,10 @@ module Server = struct Server_connection.yield_reader connection read_loop; Lwt.return_unit + | `Upgrade -> + Lwt.wakeup_later notify_upgrade_read (); + Lwt.return_unit + | `Close -> Lwt.wakeup_later notify_read_loop_exited (); if not (Lwt_unix.state socket = Lwt_unix.Closed) then begin @@ -158,7 +184,6 @@ module Server = struct let writev = Faraday_lwt_unix.writev_of_fd socket in - let write_loop_exited, notify_write_loop_exited = Lwt.wait () in let rec write_loop () = let rec write_loop_step () = @@ -172,6 +197,10 @@ module Server = struct Server_connection.yield_writer connection write_loop; Lwt.return_unit + | `Upgrade -> + Lwt.wakeup_later notify_upgrade_write (); + Lwt.return_unit + | `Close _ -> Lwt.wakeup_later notify_write_loop_exited (); if not (Lwt_unix.state socket = Lwt_unix.Closed) then begin diff --git a/lwt-unix/httpaf_lwt_unix.mli b/lwt-unix/httpaf_lwt_unix.mli index 87cd54bf..d361ea91 100644 --- a/lwt-unix/httpaf_lwt_unix.mli +++ b/lwt-unix/httpaf_lwt_unix.mli @@ -42,6 +42,7 @@ module Server : sig val create_connection_handler : ?config : Config.t -> request_handler : (Unix.sockaddr -> Server_connection.request_handler) + -> upgrade_handler : (Unix.sockaddr -> Lwt_unix.file_descr -> unit Lwt.t) option -> error_handler : (Unix.sockaddr -> Server_connection.error_handler) -> Unix.sockaddr -> Lwt_unix.file_descr