From 16b00fa31e521fdc92dd1f33fcbd5c506391689b Mon Sep 17 00:00:00 2001 From: Thomas Gazagnaire Date: Mon, 1 Feb 2021 16:21:48 +0100 Subject: [PATCH] conduit-mirage: simplify the API Do not mix first-class modules and functors anywmore as it makes the API quite complicated for not much gains: the code is generated by the mirage tools anyway. --- conduit-mirage.opam | 1 + src/conduit-mirage/conduit_mirage.ml | 379 +++++++------------ src/conduit-mirage/conduit_mirage.mli | 159 +++----- src/conduit-mirage/resolver_mirage.ml | 208 +++++----- src/conduit-mirage/resolver_mirage.mli | 31 +- tests/conduit-mirage/http-fetch/unikernel.ml | 4 - tests/conduit-mirage/simple/dune | 2 +- tests/conduit-mirage/simple/test.ml | 16 +- 8 files changed, 330 insertions(+), 470 deletions(-) diff --git a/conduit-mirage.opam b/conduit-mirage.opam index 8472c69f..074834e7 100644 --- a/conduit-mirage.opam +++ b/conduit-mirage.opam @@ -27,6 +27,7 @@ depends: [ "ca-certs-nss" "ipaddr" {>= "3.0.0"} "ipaddr-sexp" + "tcpip" {with-test} ] conflicts: [ "mirage-conduit" diff --git a/src/conduit-mirage/conduit_mirage.ml b/src/conduit-mirage/conduit_mirage.ml index 38de3af1..c8bea2c2 100644 --- a/src/conduit-mirage/conduit_mirage.ml +++ b/src/conduit-mirage/conduit_mirage.ml @@ -31,174 +31,69 @@ let err_vchan_not_supported = fail "%s: VCHAN is not supported" let err_unknown = fail "%s: unknown endpoint type" let err_ipv6 = fail "%s: IPv6 is not supported" -module Flow = struct - type error = [ `Msg of string ] - type write_error = [ Mirage_flow.write_error | error ] - - let pp_error ppf (`Msg s) = Fmt.string ppf s - - let pp_write_error ppf = function - | #Mirage_flow.write_error as e -> Mirage_flow.pp_write_error ppf e - | #error as e -> pp_error ppf e - - open Mirage_flow_combinators - - type flow = Flow : (module CONCRETE with type flow = 'a) * 'a -> flow - - let create (type a) (module M : Mirage_flow.S with type flow = a) t = - let m = (module Concrete (M) : CONCRETE with type flow = a) in - Flow (m, t) - - let read (Flow ((module F), flow)) = F.read flow - let write (Flow ((module F), flow)) b = F.write flow b - let writev (Flow ((module F), flow)) b = F.writev flow b - let close (Flow ((module F), flow)) = F.close flow -end - -type callback = Flow.flow -> unit Lwt.t - -module type Handler = sig - type t - (** Runtime handler *) - - type client [@@deriving sexp] - type server [@@deriving sexp] - - val connect : t -> client -> Flow.flow Lwt.t - val listen : t -> server -> callback -> unit Lwt.t -end - -type tcp_client = [ `TCP of Ipaddr_sexp.t * int ] [@@deriving sexp] -type tcp_server = [ `TCP of int ] [@@deriving sexp] -type 'a stackv4 = (module Mirage_stack.V4 with type t = 'a) - -let stackv4 x = x - -module type VCHAN = Vchan.S.ENDPOINT with type port = Vchan.Port.t -module type XS = Xs_client_lwt.S - -type vchan_client = - [ `Vchan of - [ `Direct of int * Vchan.Port.t (** domain id, port *) - | `Domain_socket of string * Vchan.Port.t (** Vchan Xen domain socket *) - ] ] -[@@deriving sexp] - -type vchan_server = - [ `Vchan of - [ `Direct of int * Vchan.Port.t (** domain id, port *) - | `Domain_socket (** Vchan Xen domain socket *) ] ] +let err_not_supported = function + | `TLS _ -> err_tls_not_supported + | `TCP _ -> err_tcp_not_supported + | `Vchan _ -> err_vchan_not_supported + +type client = + [ `TCP of Ipaddr_sexp.t * int + | `TLS of Tls.Config.client * client + | `Vchan of + [ `Direct of int * Vchan.Port.t | `Domain_socket of string * Vchan.Port.t ] + ] [@@deriving sexp] -type vchan = (module VCHAN) -type xs = (module XS) - -let vchan x = x -let xs x = x - -type 'a tls_client = [ `TLS of Tls.Config.client * 'a ] [@@deriving sexp] -type 'a tls_server = [ `TLS of Tls.Config.server * 'a ] [@@deriving sexp] - -type client = [ tcp_client | vchan_client | client tls_client ] +type server = + [ `TCP of int + | `TLS of Tls.Config.server * server + | `Vchan of [ `Direct of int * Vchan.Port.t | `Domain_socket ] ] [@@deriving sexp] -type server = [ tcp_server | vchan_server | server tls_server ] -[@@deriving sexp] +module type S = sig + type t + type flow -type tls_client' = client tls_client [@@deriving sexp] -type tls_server' = server tls_server [@@deriving sexp] + module Flow : Mirage_flow.S with type flow = flow -type ('c, 's) handler = - | S : - (module Handler with type t = 'a and type client = 'c and type server = 's) - * 'a - -> ('c, 's) handler + val connect : t -> client -> flow Lwt.t + val listen : t -> server -> (flow -> unit Lwt.t) -> unit Lwt.t +end +(* TCP *) let tcp_client i p = Lwt.return (`TCP (i, p)) let tcp_server _ p = Lwt.return (`TCP p) -type t = { - tcp : (tcp_client, tcp_server) handler option; - tls : (tls_client', tls_server') handler option; - vchan : (vchan_client, vchan_server) handler option; -} - -let empty = { tcp = None; tls = None; vchan = None } - -let connect t (c : client) = - match c with - | `TCP _ as x -> ( - match t.tcp with - | None -> err_tcp_not_supported "connect" - | Some (S ((module S), t)) -> S.connect t x) - | `Vchan _ as x -> ( - match t.vchan with - | None -> err_vchan_not_supported "connect" - | Some (S ((module S), t)) -> S.connect t x) - | `TLS _ as x -> ( - match t.tls with - | None -> err_tls_not_supported "connect" - | Some (S ((module S), t)) -> S.connect t x) - -let listen t (s : server) f = - match s with - | `TCP _ as x -> ( - match t.tcp with - | None -> err_tcp_not_supported "listen" - | Some (S ((module S), t)) -> S.listen t x f) - | `Vchan _ as x -> ( - match t.vchan with - | None -> err_vchan_not_supported "listen" - | Some (S ((module S), t)) -> S.listen t x f) - | `TLS _ as x -> ( - match t.tls with - | None -> err_tls_not_supported "listen" - | Some (S ((module S), t)) -> S.listen t x f) - -(******************************************************************************) -(* Implementation of handlers *) -(******************************************************************************) - -(* TCP *) - module TCP (S : Mirage_stack.V4) = struct + module Flow = S.TCPV4 + + type flow = Flow.flow type t = S.t - type client = tcp_client [@@deriving sexp] - type server = tcp_server [@@deriving sexp] let err_tcp e = Lwt.fail @@ Failure (Format.asprintf "TCP connection failed: %a" S.TCPV4.pp_error e) - let connect t (`TCP (ip, port) : client) = - match Ipaddr.to_v4 ip with - | None -> err_ipv6 "connect" - | Some ip -> ( - S.TCPV4.create_connection (S.tcpv4 t) (ip, port) >>= function - | Error e -> err_tcp e - | Ok flow -> - let flow = Flow.create (module S.TCPV4) flow in - Lwt.return flow) - - let listen t (`TCP port : server) fn = - let s, _u = Lwt.task () in - S.listen_tcpv4 t ~port (fun flow -> - let f = Flow.create (module S.TCPV4) flow in - fn f); - s -end - -module With_tcp (S : Mirage_stack.V4) = struct - module M = TCP (S) - - let handler stack = Lwt.return (S ((module M), stack)) - let connect stack t = handler stack >|= fun x -> { t with tcp = Some x } + let connect (t : t) (c : client) = + match c with + | `TCP (ip, port) -> ( + match Ipaddr.to_v4 ip with + | None -> err_ipv6 "connect" + | Some ip -> ( + S.TCPV4.create_connection (S.tcpv4 t) (ip, port) >>= function + | Error e -> err_tcp e + | Ok flow -> Lwt.return flow)) + | _ -> err_not_supported c "connect" + + let listen (t : t) (s : server) fn = + match s with + | `TCP port -> + let s, _u = Lwt.task () in + S.listen_tcpv4 t ~port (fun flow -> fn flow); + s + | _ -> err_not_supported s "listen" end -let with_tcp (type t) t (module S : Mirage_stack.V4 with type t = t) stack = - let module M = With_tcp (S) in - M.connect stack t - (* VCHAN *) let err_vchan_port = fail "%s: invalid Vchan port" @@ -217,139 +112,143 @@ let vchan_server = function | `Vchan_direct (i, p) -> port p >|= fun p -> `Vchan (`Direct (i, p)) | `Vchan_domain_socket _ -> Lwt.return (`Vchan `Domain_socket) -module Vchan (Xs : Xs_client_lwt.S) (V : VCHAN) = struct +module Vchan + (Xs : Xs_client_lwt.S) + (V : Vchan.S.ENDPOINT with type port = Vchan.Port.t) = +struct + module Flow = V module XS = Conduit_xenstore.Make (Xs) + type flow = Flow.flow type t = XS.t - type client = vchan_client [@@deriving sexp] - type server = vchan_server [@@deriving sexp] let register = XS.register - let rec connect t (c : vchan_client) = + let rec connect (t : t) (c : client) = match c with | `Vchan (`Domain_socket (uid, port)) -> XS.connect t ~remote_name:uid ~port >>= fun endp -> - connect t (`Vchan endp :> vchan_client) - | `Vchan (`Direct (domid, port)) -> - V.client ~domid ~port () >>= fun flow -> - Lwt.return (Flow.create (module V) flow) - - let listen (t : t) (server : vchan_server) fn = - match server with - | `Vchan (`Direct (domid, port)) -> - V.server ~domid ~port () >>= fun t -> fn (Flow.create (module V) t) + connect t (`Vchan endp :> client) + | `Vchan (`Direct (domid, port)) -> V.client ~domid ~port () + | _ -> err_not_supported c "connect" + + let listen (t : t) (s : server) fn = + match s with + | `Vchan (`Direct (domid, port)) -> V.server ~domid ~port () >>= fn | `Vchan `Domain_socket -> XS.listen t >>= fun conns -> Lwt_stream.iter_p - (function - | `Direct (domid, port) -> - V.server ~domid ~port () >>= fun t -> - fn (Flow.create (module V) t)) + (function `Direct (domid, port) -> V.server ~domid ~port () >>= fn) conns + | _ -> err_not_supported s "listen" end -let mk_vchan (module X : XS) (module V : VCHAN) t = - let module V = Vchan (X) (V) in - V.register t >|= fun t -> S ((module V), t) +(* TLS *) -let with_vchan t x y z = mk_vchan x y z >|= fun x -> { t with vchan = Some x } +let tls_client ~authenticator x = `TLS (Tls.Config.client ~authenticator (), x) +let tls_server ~authenticator x = `TLS (Tls.Config.server ~authenticator (), x) -(* TLS *) +module TLS (S : S) = struct + module TLS = Tls_mirage.Make (S.Flow) -let server_of_bytes str = Tls.Config.server_of_sexp (Sexplib.Sexp.of_string str) -let tls_server s x = Lwt.return (`TLS (server_of_bytes s, x)) + type flow = TLS of TLS.flow | Clear of S.flow + type t = S.t -module TLS = struct - module TLS = Tls_mirage.Make (Flow) + module Flow = struct + type nonrec flow = flow + type error = [ `Flow of S.Flow.error | `TLS of TLS.error ] - let err_flow_write m e = fail "%s: %a" m TLS.pp_write_error e + type write_error = + [ Mirage_flow.write_error + | `Flow of S.Flow.write_error + | `TLS of TLS.write_error ] - type x = t - type t = x - type client = tls_client' [@@deriving sexp] - type server = tls_server' [@@deriving sexp] - - let connect (t : t) (`TLS (c, x) : client) = - connect t x >>= fun flow -> - TLS.client_of_flow c flow >>= function - | Error e -> err_flow_write "connect" e - | Ok flow -> Lwt.return (Flow.create (module TLS) flow) - - let listen (t : t) (`TLS (c, x) : server) fn = - listen t x (fun flow -> - TLS.server_of_flow c flow >>= function - | Error e -> err_flow_write "listen" e - | Ok flow -> fn (Flow.create (module TLS) flow)) -end + let pp_error ppf = function + | `Flow e -> S.Flow.pp_error ppf e + | `TLS e -> TLS.pp_error ppf e -let tls t = Lwt.return (S ((module TLS), t)) -let with_tls t = tls t >|= fun x -> { t with tls = Some x } + let pp_write_error ppf = function + | #Mirage_flow.write_error as e -> Mirage_flow.pp_write_error ppf e + | `Flow e -> S.Flow.pp_write_error ppf e + | `TLS e -> TLS.pp_write_error ppf e -type conduit = t + let tls_err = function Ok _ as x -> x | Error e -> Error (`TLS e) + let flow_err = function Ok _ as x -> x | Error e -> Error (`Flow e) -module type S = sig - type t = conduit + let tls_write_err = function + | Ok _ as x -> x + | Error `Closed as x -> x + | Error e -> Error (`TLS e) + + let flow_write_err = function + | Ok _ as x -> x + | Error `Closed as x -> x + | Error e -> Error (`Flow e) + + let read = function + | TLS f -> TLS.read f >|= tls_err + | Clear f -> S.Flow.read f >|= flow_err + + let write t x = + match t with + | TLS f -> TLS.write f x >|= tls_write_err + | Clear f -> S.Flow.write f x >|= flow_write_err - val empty : t + let writev t x = + match t with + | TLS f -> TLS.writev f x >|= tls_err + | Clear f -> S.Flow.writev f x >|= flow_err - module With_tcp (S : Mirage_stack.V4) : sig - val connect : S.t -> t -> t Lwt.t + let close = function TLS f -> TLS.close f | Clear f -> S.Flow.close f end - val with_tcp : t -> 'a stackv4 -> 'a -> t Lwt.t - val with_tls : t -> t Lwt.t - val with_vchan : t -> xs -> vchan -> string -> t Lwt.t - val connect : t -> client -> Flow.flow Lwt.t - val listen : t -> server -> callback -> unit Lwt.t + let err_flow_write m e = fail "%s: %a" m TLS.pp_write_error e + + let connect (t : t) (c : client) = + match c with + | `TLS (c, x) -> ( + S.connect t x >>= fun flow -> + TLS.client_of_flow c flow >>= function + | Error e -> err_flow_write "connect" e + | Ok flow -> Lwt.return (TLS flow)) + | _ -> S.connect t c >|= fun t -> Clear t + + let listen (t : t) (s : server) fn = + match s with + | `TLS (c, x) -> + S.listen t x (fun flow -> + TLS.server_of_flow c flow >>= function + | Error e -> err_flow_write "listen" e + | Ok flow -> fn (TLS flow)) + | _ -> S.listen t s (fun f -> fn (Clear f)) end -module Client (P : Mirage_clock.PCLOCK) = struct +module Endpoint (P : Mirage_clock.PCLOCK) = struct module Ca_certs = Ca_certs_nss.Make (P) - let default_authenticator = + let nss_authenticator = match Ca_certs.authenticator () with | Ok a -> a | Error (`Msg msg) -> failwith msg - let tls_client ~authenticator x = `TLS (Tls.Config.client ~authenticator (), x) - - let rec resolve ?(tls_authenticator = default_authenticator) e = + let rec client ?(tls_authenticator = nss_authenticator) e = match e with | `TCP (x, y) -> tcp_client x y | `Unix_domain_socket _ -> err_domain_sockets_not_supported "client" | (`Vchan_direct _ | `Vchan_domain_socket _) as x -> vchan_client x - | `TLS (_, y) -> - resolve ~tls_authenticator y + | `TLS (_host, y) -> + client ~tls_authenticator y >|= tls_client ~authenticator:tls_authenticator | `Unknown s -> err_unknown s -end - -let rec server (e : Conduit.endp) : server Lwt.t = - match e with - | `TCP (x, y) -> tcp_server x y - | `Unix_domain_socket _ -> err_domain_sockets_not_supported "server" - | (`Vchan_direct _ | `Vchan_domain_socket _) as x -> vchan_server x - | `TLS (x, y) -> server y >>= fun s -> tls_server x s - | `Unknown s -> err_unknown s - -module Context - (R : Mirage_random.S) - (T : Mirage_time.S) - (C : Mirage_clock.MCLOCK) - (S : Mirage_stack.V4) = -struct - type t = Resolver_lwt.t * conduit - module RES = Resolver_mirage.Make_with_stack (R) (T) (C) (S) + let ok_authenticator ~host:_ _ = Ok None - let conduit = empty - let stackv4 = stackv4 (module S : Mirage_stack.V4 with type t = S.t) - - let create ?(tls = false) stack = - let res = Resolver_lwt.init () in - RES.R.register ~stack res; - with_tcp conduit stackv4 stack >>= fun conduit -> - if tls then with_tls conduit >|= fun conduit -> (res, conduit) - else Lwt.return (res, conduit) + let rec server ?(tls_authenticator = ok_authenticator) e = + match e with + | `TCP (x, y) -> tcp_server x y + | `Unix_domain_socket _ -> err_domain_sockets_not_supported "server" + | (`Vchan_direct _ | `Vchan_domain_socket _) as x -> vchan_server x + | `TLS (_host, y) -> + server y >|= tls_server ~authenticator:tls_authenticator + | `Unknown s -> err_unknown s end diff --git a/src/conduit-mirage/conduit_mirage.mli b/src/conduit-mirage/conduit_mirage.mli index d3baaab3..70852894 100644 --- a/src/conduit-mirage/conduit_mirage.mli +++ b/src/conduit-mirage/conduit_mirage.mli @@ -19,136 +19,87 @@ (** Functorial connection establishment interface that is compatible with the Mirage libraries. *) -module Flow : Mirage_flow.S -(** Dynamic flows. *) - -type callback = Flow.flow -> unit Lwt.t -(** The type for callback values. *) - -module type Handler = sig - (** The signature for runtime handlers *) - - type t - (** The type for runtime handlers. *) - - type client [@@deriving sexp] - (** The type for client configuration values. *) - - type server [@@deriving sexp] - (** The type for server configuration values. *) - - val connect : t -> client -> Flow.flow Lwt.t - (** Connect a conduit using client configuration. *) - - val listen : t -> server -> callback -> unit Lwt.t - (** Listen to a conduit using a server configuration. *) -end - -(** {2 TCP} *) - -(** The type for client connections. *) - -type tcp_client = [ `TCP of Ipaddr.t * int ] -(** address and destination port *) - -and tcp_server = [ `TCP of int ] -(** listening port *) - -type 'a stackv4 - -val stackv4 : (module Mirage_stack.V4 with type t = 'a) -> 'a stackv4 - -(** {2 VCHAN} *) - -type vchan_client = - [ `Vchan of +type client = + [ `TCP of Ipaddr.t * int (** address and destination port *) + | `TLS of Tls.Config.client * client + | `Vchan of [ `Direct of int * Vchan.Port.t (** domain id, port *) | `Domain_socket of string * Vchan.Port.t (** Vchan Xen domain socket *) ] ] +[@@deriving sexp] +(** The type for client configuration values. *) -type vchan_server = - [ `Vchan of +type server = + [ `TCP of int (** listening port *) + | `TLS of Tls.Config.server * server + | `Vchan of [ `Direct of int * Vchan.Port.t (** domain id, port *) | `Domain_socket (** Vchan Xen domain socket *) ] ] +[@@deriving sexp] +(** The type for server configuration values. *) -module type VCHAN = Vchan.S.ENDPOINT with type port = Vchan.Port.t -module type XS = Xs_client_lwt.S - -type vchan -type xs - -val vchan : (module VCHAN) -> vchan -val xs : (module XS) -> xs +module Endpoint (P : Mirage_clock.PCLOCK) : sig + val nss_authenticator : X509.Authenticator.t + (** [nss_authenticator] is the validator using the + {{:https://github.com/mirage/ca-certs-nss} trust anchors extracted from + Mozilla's NSS}. *) -(** {2 TLS} *) + val ok_authenticator : X509.Authenticator.t + (** [ok_authenticator] is the validator which accepts all certificates. *) -type 'a tls_client = [ `TLS of Tls.Config.client * 'a ] -type 'a tls_server = [ `TLS of Tls.Config.server * 'a ] - -type client = [ tcp_client | vchan_client | client tls_client ] -[@@deriving sexp] -(** The type for client configuration values. *) + val client : + ?tls_authenticator:X509.Authenticator.t -> Conduit.endp -> client Lwt.t + (** [client] resolves a conduit endpoint into a client configuration. -type server = [ tcp_server | vchan_server | server tls_server ] -[@@deriving sexp] -(** The type for server configuration values. *) + The certificate is validated using [tls_authenticator]. By default, it is + [nss_authenticator] *) -module Client (P : Mirage_clock.PCLOCK) : sig - val resolve : - ?tls_authenticator:X509.Authenticator.t -> Conduit.endp -> client Lwt.t - (** Resolve a conduit endpoint into a client configuration. + val server : + ?tls_authenticator:X509.Authenticator.t -> Conduit.endp -> server Lwt.t + (** [server] resolves a confuit endpoint into a server configuration. - The certificate is validated using [tls_authenticator]. By default, the - validation is using the {{:https://github.com/mirage/ca-certs-nss} trust - anchors extracted from Mozilla's NSS}. *) + The certificate is validated using [tls_authenticator]. By default, it is + [ok_authenticator]. *) end -val server : Conduit.endp -> server Lwt.t -(** Resolve a confuit endpoint into a server configuration. *) +module type S = sig + (** The signature for conduits *) -type conduit -(** The type for conduit values. *) + type flow + (** The type for networking flows. *) -module type S = sig - (** The signature for Conduit implementations. *) + type t + (** The type for handlers. *) - type t = conduit + module Flow : Mirage_flow.S with type flow = flow + (** The type for flows. *) - val empty : t - (** The empty conduit. *) + val connect : t -> client -> flow Lwt.t + (** Connect a conduit using client configuration. *) - module With_tcp (S : Mirage_stack.V4) : sig - val connect : S.t -> t -> t Lwt.t - end + val listen : t -> server -> (flow -> unit Lwt.t) -> unit Lwt.t + (** Listen to a conduit using a server configuration. *) +end - val with_tcp : t -> 'a stackv4 -> 'a -> t Lwt.t - (** Extend a conduit with an implementation for TCP. *) +(** {2 TCP} *) - val with_tls : t -> t Lwt.t - (** Extend a conduit with an implementation for TLS. *) +module TCP (S : Mirage_stack.V4) : + S with type t = S.t and type flow = S.TCPV4.flow - val with_vchan : t -> xs -> vchan -> string -> t Lwt.t - (** Extend a conduit with an implementation for VCHAN. *) +(** {2 VCHAN} *) - val connect : t -> client -> Flow.flow Lwt.t - (** Connect a conduit using a client configuration value. *) +module Vchan + (X : Xs_client_lwt.S) + (V : Vchan.S.ENDPOINT with type port = Vchan.Port.t) : sig + include S - val listen : t -> server -> callback -> unit Lwt.t - (** Configure a server using a conduit configuration value. *) + val register : string -> t Lwt.t end -include S +(** {2 TLS} *) -(** {2 Context for MirageOS conduit resolvers} *) -module Context - (R : Mirage_random.S) - (T : Mirage_time.S) - (C : Mirage_clock.MCLOCK) - (S : Mirage_stack.V4) : sig - type t = Resolver_lwt.t * conduit - (** The type for contexts of conduit resolvers. *) +module TLS (S : S) : sig + type flow = TLS of Tls_mirage.Make(S.Flow).flow | Clear of S.flow - val create : ?tls:bool -> S.t -> t Lwt.t - (** Create a new context. If [tls] is specified (by defaut, it is not), set-up - the conduit to accept TLS connections. *) + include S with type t = S.t and type flow := flow end diff --git a/src/conduit-mirage/resolver_mirage.ml b/src/conduit-mirage/resolver_mirage.ml index 30fe84b5..4a842fba 100644 --- a/src/conduit-mirage/resolver_mirage.ml +++ b/src/conduit-mirage/resolver_mirage.ml @@ -17,50 +17,14 @@ open Lwt.Infix -let is_tls_service = - (* TODO fill in the blanks. nowhere else to get this information *) - function - | "https" | "imaps" -> true - | _ -> false - -let get_host uri = - match Uri.host uri with - | None -> "localhost" - | Some host -> ( - match Ipaddr.of_string host with - | Ok ip -> Ipaddr.to_string ip - | Error _ -> host) - -let get_port service uri = - match Uri.port uri with None -> service.Resolver.port | Some port -> port - -let static_resolver hosts service uri = - let port = get_port service uri in - try - let fn = Hashtbl.find hosts (get_host uri) in - Lwt.return (fn ~port) - with Not_found -> Lwt.return (`Unknown "name resolution failed") - -let static_service name = - match Uri_services.tcp_port_of_service name with - | [] -> Lwt.return_none - | port :: _ -> - let tls = is_tls_service name in - let svc = { Resolver.name; port; tls } in - Lwt.return (Some svc) - -let static hosts = - let service = static_service in - let rewrites = [ ("", static_resolver hosts) ] in - Resolver_lwt.init ~service ~rewrites () - -let localhost = - let hosts = Hashtbl.create 3 in - Hashtbl.add hosts "localhost" (fun ~port -> - `TCP (Ipaddr.(V4 V4.localhost), port)); - static hosts - -module Make_with_stack +module type S = sig + include Resolver_lwt.S + + val static : (string, port:int -> Conduit.endp) Hashtbl.t -> t + val localhost : t +end + +module Make (R : Mirage_random.S) (T : Mirage_time.S) (C : Mirage_clock.MCLOCK) @@ -68,62 +32,102 @@ module Make_with_stack struct include Resolver_lwt - module R = struct - let vchan_resolver ~tld = - let tld_len = String.length tld in - let get_short_host uri = - let n = get_host uri in - let len = String.length n in - if len > tld_len && String.sub n (len - tld_len) tld_len = tld then - String.sub n 0 (len - tld_len) - else n - in - fun service uri -> - (* Strip the tld from the hostname *) - let remote_name = get_short_host uri in - Printf.printf "vchan_lookup: %s %s -> normalizes to %s\n%!" - (Sexplib.Sexp.to_string_hum (Resolver.sexp_of_service service)) - (Uri.to_string uri) remote_name; - Lwt.return (`Vchan_domain_socket (remote_name, service.Resolver.name)) - - module DNS = Dns_client_mirage.Make (R) (T) (C) (S) - - let dns_stub_resolver dns service uri : Conduit.endp Lwt.t = - let hostn = get_host uri in - let port = get_port service uri in - (match Ipaddr.V4.of_string hostn with - | Ok addr -> Lwt.return (Ok addr) - | Error _ -> ( - match Domain_name.of_string hostn with - | Error (`Msg msg) -> Lwt.return (Error (`Msg msg)) - | Ok domain -> ( - match Domain_name.host domain with - | Error (`Msg msg) -> Lwt.return (Error (`Msg msg)) - | Ok host -> DNS.gethostbyname dns host))) - >|= function - | Error (`Msg err) -> `Unknown ("name resolution failed: " ^ err) - | Ok addr -> `TCP (Ipaddr.V4 addr, port) - - let register ?ns ?(ns_port = 53) ?stack res = - (match stack with - | Some s -> - (* DNS stub resolver *) - let nameserver = - match ns with None -> None | Some ip -> Some (`TCP, (ip, ns_port)) - in - let dns = DNS.create ?nameserver s in - let f = dns_stub_resolver dns in - Resolver_lwt.add_rewrite ~host:"" ~f res - | None -> ()); - let service = Resolver_lwt.(service res ++ static_service) in - Resolver_lwt.set_service ~f:service res; - let vchan_tld = ".xen" in - let vchan_res = vchan_resolver ~tld:vchan_tld in - Resolver_lwt.add_rewrite ~host:vchan_tld ~f:vchan_res res - - let init ?ns ?ns_port ?stack () = - let res = Resolver_lwt.init () in - register ?ns ?ns_port ?stack res; - res - end + let is_tls_service = + (* TODO fill in the blanks. nowhere else to get this information *) + function + | "https" | "imaps" -> true + | _ -> false + + let get_host uri = + match Uri.host uri with + | None -> "localhost" + | Some host -> ( + match Ipaddr.of_string host with + | Ok ip -> Ipaddr.to_string ip + | Error _ -> host) + + let get_port service uri = + match Uri.port uri with None -> service.Resolver.port | Some port -> port + + let static_resolver hosts service uri = + let port = get_port service uri in + try + let fn = Hashtbl.find hosts (get_host uri) in + Lwt.return (fn ~port) + with Not_found -> Lwt.return (`Unknown "name resolution failed") + + let static_service name = + match Uri_services.tcp_port_of_service name with + | [] -> Lwt.return_none + | port :: _ -> + let tls = is_tls_service name in + let svc = { Resolver.name; port; tls } in + Lwt.return (Some svc) + + let static hosts = + let service = static_service in + let rewrites = [ ("", static_resolver hosts) ] in + Resolver_lwt.init ~service ~rewrites () + + let localhost = + let hosts = Hashtbl.create 3 in + Hashtbl.add hosts "localhost" (fun ~port -> + `TCP (Ipaddr.(V4 V4.localhost), port)); + static hosts + + let vchan_resolver ~tld = + let tld_len = String.length tld in + let get_short_host uri = + let n = get_host uri in + let len = String.length n in + if len > tld_len && String.sub n (len - tld_len) tld_len = tld then + String.sub n 0 (len - tld_len) + else n + in + fun service uri -> + (* Strip the tld from the hostname *) + let remote_name = get_short_host uri in + Printf.printf "vchan_lookup: %s %s -> normalizes to %s\n%!" + (Sexplib.Sexp.to_string_hum (Resolver.sexp_of_service service)) + (Uri.to_string uri) remote_name; + Lwt.return (`Vchan_domain_socket (remote_name, service.Resolver.name)) + + module DNS = Dns_client_mirage.Make (R) (T) (C) (S) + + let dns_stub_resolver dns service uri : Conduit.endp Lwt.t = + let hostn = get_host uri in + let port = get_port service uri in + (match Ipaddr.V4.of_string hostn with + | Ok addr -> Lwt.return (Ok addr) + | Error _ -> ( + match Domain_name.of_string hostn with + | Error (`Msg msg) -> Lwt.return (Error (`Msg msg)) + | Ok domain -> ( + match Domain_name.host domain with + | Error (`Msg msg) -> Lwt.return (Error (`Msg msg)) + | Ok host -> DNS.gethostbyname dns host))) + >|= function + | Error (`Msg err) -> `Unknown ("name resolution failed: " ^ err) + | Ok addr -> `TCP (Ipaddr.V4 addr, port) + + let register ?ns ?(ns_port = 53) s res = + (* DNS stub resolver *) + let nameserver = + match ns with None -> None | Some ip -> Some (`TCP, (ip, ns_port)) + in + let dns = DNS.create ?nameserver s in + let f = dns_stub_resolver dns in + Resolver_lwt.add_rewrite ~host:"" ~f res; + let service = Resolver_lwt.(service res ++ static_service) in + Resolver_lwt.set_service ~f:service res; + let vchan_tld = ".xen" in + let vchan_res = vchan_resolver ~tld:vchan_tld in + Resolver_lwt.add_rewrite ~host:vchan_tld ~f:vchan_res res + + let v ?ns ?ns_port stack = + let res = Resolver_lwt.init () in + register ?ns ?ns_port stack res; + res + + type t = Resolver_lwt.t end diff --git a/src/conduit-mirage/resolver_mirage.mli b/src/conduit-mirage/resolver_mirage.mli index 905727de..80e8d6d6 100644 --- a/src/conduit-mirage/resolver_mirage.mli +++ b/src/conduit-mirage/resolver_mirage.mli @@ -17,27 +17,26 @@ (** Functorial interface for resolving URIs to endpoints. *) -val static : (string, port:int -> Conduit.endp) Hashtbl.t -> Resolver_lwt.t -(** [static hosts] constructs a resolver that looks up any resolution requests - from the static [hosts] hashtable instead of using the system resolver. *) +module type S = sig + include Resolver_lwt.S -val localhost : Resolver_lwt.t -(** [localhost] is a static resolver that has a single entry that maps - [localhost] to [127.0.0.1], and fails on all other hostnames. *) + val static : (string, port:int -> Conduit.endp) Hashtbl.t -> t + (** [static hosts] constructs a resolver that looks up any resolution requests + from the static [hosts] hashtable instead of using the system resolver. *) -(** Provides a DNS-enabled {!Resolver_lwt} given a network stack. See {!Make}. *) -module Make_with_stack + val localhost : t + (** [localhost] is a static resolver that has a single entry that maps + [localhost] to [127.0.0.1], and fails on all other hostnames. *) +end + +(** Provides a DNS-enabled {!Resolver_lwt} given a network stack. *) +module Make (R : Mirage_random.S) (T : Mirage_time.S) (C : Mirage_clock.MCLOCK) (S : Mirage_stack.V4) : sig - include Resolver_lwt.S with type t = Resolver_lwt.t - - module R : sig - val register : - ?ns:Ipaddr.V4.t -> ?ns_port:int -> ?stack:S.t -> Resolver_lwt.t -> unit + include S - val init : ?ns:Ipaddr.V4.t -> ?ns_port:int -> ?stack:S.t -> unit -> t - (** [init ?ns ?ns_port ?stack ()] TODO *) - end + val v : ?ns:Ipaddr.V4.t -> ?ns_port:int -> S.t -> t + (** [v ?ns ?ns_port ?stack ()] TODO *) end diff --git a/tests/conduit-mirage/http-fetch/unikernel.ml b/tests/conduit-mirage/http-fetch/unikernel.ml index 09598073..21c51c2c 100644 --- a/tests/conduit-mirage/http-fetch/unikernel.ml +++ b/tests/conduit-mirage/http-fetch/unikernel.ml @@ -14,10 +14,6 @@ module Client (C : CONSOLE) (S : STACKV4) = struct module DNS = Dns_resolver_mirage.Make (OS.Time) (S) module RES = Resolver_mirage.Make (DNS) - let mk_conduit s = - let stackv4 = Conduit_mirage.stackv4 (module S) in - Conduit_mirage.with_tcp Conduit_mirage.empty stackv4 s - let start c stack _ = C.log_s c (sprintf "Resolving in 3s using DNS server %s" ns) >>= fun () -> OS.Time.sleep 3.0 >>= fun () -> diff --git a/tests/conduit-mirage/simple/dune b/tests/conduit-mirage/simple/dune index 5af32969..47792dd5 100644 --- a/tests/conduit-mirage/simple/dune +++ b/tests/conduit-mirage/simple/dune @@ -1,4 +1,4 @@ (test (name test) - (libraries conduit-mirage) + (libraries conduit-mirage tcpip.stack-socket) (package conduit-mirage)) diff --git a/tests/conduit-mirage/simple/test.ml b/tests/conduit-mirage/simple/test.ml index fdf1bcbd..28a22747 100644 --- a/tests/conduit-mirage/simple/test.ml +++ b/tests/conduit-mirage/simple/test.ml @@ -1,8 +1,18 @@ -(* this is just to test that linking works properly *) +open Lwt.Infix let client : Conduit_mirage.client = `TCP (Ipaddr.of_string_exn "127.0.0.1", 12345) let server : Conduit_mirage.server = `TCP 12345 -let _client () = Conduit_mirage.(connect empty) client -let _server () = Conduit_mirage.(listen empty) server + +module TCP = Conduit_mirage.TCP (Tcpip_stack_socket.V4) + +let tcp () = + Udpv4_socket.connect Ipaddr.V4.Prefix.global >>= fun udp -> + Tcpv4_socket.connect Ipaddr.V4.Prefix.global >>= fun tcp -> + Tcpip_stack_socket.V4.connect udp tcp + +let _client () = tcp () >>= fun t -> TCP.connect t client + +let _server () = + tcp () >>= fun t -> TCP.listen t server (fun _flow -> Lwt.return ())