Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Direct-style websockets with Eio #130

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
_build
**/*.merlin
*.install
*.install
.vscode
8 changes: 4 additions & 4 deletions async/dune
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
(library
(name websocket_async)
(public_name websocket-async)
(modules websocket_async)
(optional)
(libraries websocket logs-async cohttp-async))
(public_name websocket-async)
(modules websocket_async)
(optional)
(libraries websocket logs-async cohttp-async))

(executable
(name wscat)
Expand Down
6 changes: 6 additions & 0 deletions core/websocket.mli
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ val upgrade_present : Cohttp.Header.t -> bool

exception Protocol_error of string

val proto_error : ('b, Format.formatter, unit, 'a) format4 -> 'b

module Rng : sig
val init : ?state:Random.State.t -> unit -> int -> string
(** [init ?state ()] is a function that returns a string of random
Expand All @@ -40,6 +42,9 @@ module Frame : sig
| Nonctrl of int

val to_string : t -> string
val to_enum : t -> int
val of_enum : int -> t
val is_ctrl : t -> bool
val pp : Format.formatter -> t -> unit
end

Expand All @@ -57,6 +62,7 @@ module Frame : sig
t

val close : int -> t
val of_bytes : ?opcode:Opcode.t -> ?extension:int -> ?final:bool -> bytes -> t
end

val check_origin :
Expand Down
1 change: 1 addition & 0 deletions dune
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
(vendored_dirs ocaml-cohttp)
25 changes: 25 additions & 0 deletions dune-project
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,28 @@
(websocket (= :version))
(lwt_log (>= 1.1.1))
(cohttp-lwt-unix (>= 5.0.0))))

(package
(name websocket-eio)
(synopsis "Websocket library (Eio)")
(description
"\| The WebSocket Protocol enables two-way communication between a client
"\| running untrusted code in a controlled environment to a remote host
"\| that has opted-in to communications from that code.
"\|
"\| The security model used for this is the origin-based security model
"\| commonly used by web browsers. The protocol consists of an opening
"\| handshake followed by basic message framing, layered over TCP.
"\|
"\| The goal of this technology is to provide a mechanism for
"\| browser-based applications that need two-way communication with
"\| servers that does not rely on opening multiple HTTP connections (e.g.,
"\| using XMLHttpRequest or <iframe>s and long polling).
)
(tags (org:mirage org:xapi-project))
(depends
(ocaml (>= 5.0.0))
(websocket (= :version))
eio
cohttp-eio
(eio_main :with-test)))
4 changes: 4 additions & 0 deletions eio/dune
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
(library
(name websocket_eio)
(public_name websocket-eio)
(libraries websocket cohttp-eio))
163 changes: 163 additions & 0 deletions eio/io.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
open Websocket
open Astring
open Eio

type mode = Client of (int -> string) | Server

let is_client mode = mode <> Server

let xor mask msg =
for i = 0 to Bytes.length msg - 1 do
(* masking msg to send *)
Bytes.set msg i
Char.(to_int mask.[i mod 4] lxor to_int (Bytes.get msg i) |> of_byte)
done

let is_bit_set idx v = (v lsr idx) land 1 = 1
let set_bit v idx b = if b then v lor (1 lsl idx) else v land lnot (1 lsl idx)
let int_value shift len v = (v lsr shift) land ((1 lsl len) - 1)

let read_exactly src remaining =
patricoferris marked this conversation as resolved.
Show resolved Hide resolved
try
Some (Buf_read.take remaining src)
with End_of_file -> None

let read_uint16 ic =
match read_exactly ic 2 with
| None -> None
| Some s -> Some (EndianString.BigEndian.get_uint16 s 0)

let read_int64 ic =
match read_exactly ic 8 with
| None -> None
| Some s -> Some (Int64.to_int @@ EndianString.BigEndian.get_int64 s 0)

let write_frame_to_buf ~mode buf fr =
let open Frame in
let content = Bytes.unsafe_of_string fr.content in
let len = Bytes.length content in
let opcode = Opcode.to_enum fr.opcode in
let payload_len =
match len with
| n when n < 126 -> len
| n when n < 1 lsl 16 -> 126
| _ -> 127 in
let hdr = set_bit 0 15 fr.final in
(* We do not support extensions for now *)
let hdr = hdr lor (opcode lsl 8) in
let hdr = set_bit hdr 7 (is_client mode) in
let hdr = hdr lor payload_len in
(* Payload len is guaranteed to fit in 7 bits *)
Buf_write.BE.uint16 buf hdr;
( match len with
| n when n < 126 -> ()
| n when n < 1 lsl 16 ->
Buf_write.BE.uint16 buf n
| n ->
Buf_write.BE.uint64 buf Int64.(of_int n);
);
( match mode with
| Server -> ()
| Client random_string ->
let mask = random_string 4 in
Buf_write.string buf mask ;
if len > 0 then xor mask content ) ;
Buf_write.bytes buf content

let close_with_code mode dst code =
write_frame_to_buf ~mode dst @@ Frame.close code

let read_frame ic oc mode hdr =
let hdr_part1 = EndianString.BigEndian.get_int8 hdr 0 in
let hdr_part2 = EndianString.BigEndian.get_int8 hdr 1 in
let final = is_bit_set 7 hdr_part1 in
let extension = int_value 4 3 hdr_part1 in
let opcode = int_value 0 4 hdr_part1 in
let frame_masked = is_bit_set 7 hdr_part2 in
let length = int_value 0 7 hdr_part2 in
let opcode = Frame.Opcode.of_enum opcode in
let payload_len =
match length with
| i when i < 126 -> i
| 126 -> ( match read_uint16 ic with Some i -> i | None -> -1 )
| 127 -> ( match read_int64 ic with Some i -> i | None -> -1 )
| _ -> -1 in
if payload_len = -1 then proto_error "payload len = %d" length
else if extension <> 0 then (
close_with_code mode oc 1002 ;
proto_error "unsupported extension" )
else if Frame.Opcode.is_ctrl opcode && payload_len > 125 then (
close_with_code mode oc 1002 ;
proto_error "control frame too big" )
else
let mask =
if frame_masked then (
match read_exactly ic 4 with
| None -> proto_error "could not read mask"
| Some mask -> mask )
else String.empty in
if payload_len = 0 then Frame.create ~opcode ~extension ~final ()
else (
match read_exactly ic payload_len with
| None -> proto_error "could not read payload (len=%d)" payload_len
| Some payload ->
let payload = Bytes.unsafe_of_string payload in
if frame_masked then xor mask payload ;
let frame = Frame.of_bytes ~opcode ~extension ~final payload in
frame )

let make_read_frame ~mode ic oc () =
match read_exactly ic 2 with
| None -> raise End_of_file
| Some hdr -> read_frame ic oc mode hdr

module Connected_client = struct
type t =
{ buffer: Buf_write.t;
endp: Conduit.endp;
ic: Buf_read.t;
http_request: Cohttp.Request.t;
standard_frame_replies: bool;
read_frame: unit -> Frame.t }

let source {endp; _} = endp

let create http_request endp ic oc =
let read_frame = make_read_frame ~mode:Server ic oc in
{ buffer = oc;
endp;
ic;
http_request;
standard_frame_replies = false;
read_frame }

let send {buffer; _} frame =
write_frame_to_buf ~mode:Server buffer frame

let send_multiple {buffer; _} frames =
List.iter (write_frame_to_buf ~mode:Server buffer) frames

let standard_recv t =
let fr = t.read_frame () in
match fr.Frame.opcode with
| Frame.Opcode.Ping ->
send t @@ Frame.create ~opcode:Frame.Opcode.Pong () ;
fr
| Frame.Opcode.Close ->
(* Immediately echo and pass this last message to the user *)
if String.length fr.Frame.content >= 2 then
send t
@@ Frame.create ~opcode:Frame.Opcode.Close
~content:
String.(sub ~start:0 ~stop:2 fr.Frame.content |> Sub.to_string)
()
else send t @@ Frame.close 1000 ;
fr
| _ -> fr

let recv t =
if t.standard_frame_replies then standard_recv t else t.read_frame ()

let http_request {http_request; _} = http_request
let make_standard t = {t with standard_frame_replies= true}
end
62 changes: 62 additions & 0 deletions eio/websocket_eio.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
(*
* Copyright (c) 2016-2018 Maciej Wos <[email protected]>
* Copyright (c) 2012-2018 Vincent Bernardoff <[email protected]>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*
*)
open Websocket
module Ws_io = Io

let send_frames stream (oc : Eio.Buf_write.t) =
let rec send_frame stream =
let fr = Eio.Stream.take stream in
Ws_io.write_frame_to_buf ~mode:Server oc fr ;
send_frame stream in
send_frame stream

let read_frames ic oc handler_fn : unit =
let read_frame = Ws_io.make_read_frame ~mode:Server ic oc in
let rec inner () =
handler_fn @@ read_frame () ;
inner () in
inner ()

let upgrade_connection (request : Http.Request.t) incoming_handler =
let request = request in
let headers = Http.Request.headers request in
let key =
match Http.Header.get headers "sec-websocket-key" with
| None ->
invalid_arg "upgrade_connection: missing header `sec-websocket-key`"
| Some key -> key in
let hash = b64_encoded_sha1sum (key ^ websocket_uuid) in
let response_headers =
Http.Header.of_list
[ ("Upgrade", "websocket"); ("Connection", "Upgrade");
("Sec-WebSocket-Accept", hash) ] in
let frames_out_stream = Eio.Stream.create max_int in
let frames_out_fn v = Eio.Stream.add frames_out_stream v in
let f (ic : Eio.Buf_read.t) (oc : Eio.Buf_write.t) =
Eio.Fiber.both
(* output: data for the client is written to the output
* channel of the tcp connection *)
(fun () -> send_frames frames_out_stream oc )
(* input: data from the client is read from the input channel
* of the tcp connection; pass it to handler function *)
(fun () -> read_frames ic oc incoming_handler ) in
let resp : Cohttp_eio.Server.response_action =
let http_response = Http.Response.make ~status:`Switching_protocols ~version:`HTTP_1_1 ~headers:response_headers () in
`Expert (http_response, f)
in
(resp, frames_out_fn)
12 changes: 12 additions & 0 deletions eio/websocket_eio.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
open Websocket

val upgrade_connection :
Http.Request.t ->
(Frame.t -> unit) ->
Cohttp_eio.Server.response_action * (Frame.t -> unit)
(** [upgrade_connection req incoming_handler] takes [req], a
connection request, and [incoming_handler], a function that will
process incoming websocket frames, and returns ([response_action],
[push_frame]) where [response_action] is used to produce a
{!Cohttp_lwt.Server.t} and [push_frame] is used to send websocket
frames to the client. *)
7 changes: 6 additions & 1 deletion test/dune
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,9 @@
(executable
(name upgrade_connection)
(modules upgrade_connection)
(libraries logs.fmt logs.lwt websocket_cohttp_lwt))
(libraries lwt.unix logs.fmt logs.lwt websocket_cohttp_lwt))

(executable
(name eio_upgrade_connection)
(modules eio_upgrade_connection)
(libraries eio_main websocket_eio))
Loading