diff --git a/mirage-conduit.opam b/mirage-conduit.opam index 339b2b92..2a073f7c 100644 --- a/mirage-conduit.opam +++ b/mirage-conduit.opam @@ -14,7 +14,7 @@ depends: [ "mirage-stack-lwt" {>= "1.3.0"} "mirage-time-lwt" {>= "1.1.0"} "mirage-flow-lwt" {>= "1.2.0"} - "mirage-dns" {>= "3.0.0"} + "dns-client" {>= "4.0.0"} "conduit-lwt" "vchan" {>= "3.0.0"} "xenstore" diff --git a/mirage/conduit_mirage.ml b/mirage/conduit_mirage.ml index 37b8f478..4e0cdd59 100644 --- a/mirage/conduit_mirage.ml +++ b/mirage/conduit_mirage.ml @@ -327,19 +327,18 @@ let rec server (e:Conduit.endp): server Lwt.t = match e with | `TLS (x, y) -> server y >>= fun s -> tls_server x s | `Unknown s -> err_unknown s -module Context (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4) = struct +module Context (R: Mirage_random.C) (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4) = struct type t = Resolver_lwt.t * conduit - module DNS = Dns_resolver_mirage.Make(T)(S) - module RES = Resolver_mirage.Make(DNS) + module RES = Resolver_mirage.Make_with_stack(R)(T)(S) let conduit = empty let stackv4 = stackv4 (module S: Mirage_stack_lwt.V4 with type t = S.t) let create ?(tls=false) stack = let res = Resolver_lwt.init () in - RES.register ~stack res; + RES.R.register ~stack res; with_tcp conduit stackv4 stack >>= fun conduit -> if tls then with_tls conduit >|= fun conduit -> diff --git a/mirage/conduit_mirage.mli b/mirage/conduit_mirage.mli index 2d7798a4..5a9af307 100644 --- a/mirage/conduit_mirage.mli +++ b/mirage/conduit_mirage.mli @@ -131,7 +131,7 @@ end include S (** {2 Context for MirageOS conduit resolvers} *) -module Context (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4): sig +module Context (R: Mirage_random.C) (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4): sig type t = Resolver_lwt.t * conduit (** The type for contexts of conduit resolvers. *) diff --git a/mirage/dune b/mirage/dune index bdb0615a..38b103eb 100644 --- a/mirage/dune +++ b/mirage/dune @@ -6,5 +6,5 @@ (wrapped false) (optional) (libraries conduit conduit-lwt mirage-stack-lwt mirage-time-lwt - mirage-flow-lwt mirage-dns ipaddr-sexp - vchan tls tls.mirage xenstore.client)) + mirage-random mirage-flow-lwt dns-client.mirage ipaddr-sexp + vchan tls tls.mirage xenstore.client uri.services)) diff --git a/mirage/resolver_mirage.ml b/mirage/resolver_mirage.ml index 8bd91975..48b6bf91 100644 --- a/mirage/resolver_mirage.ml +++ b/mirage/resolver_mirage.ml @@ -63,63 +63,50 @@ let localhost = (fun ~port -> `TCP (Ipaddr.(V4 V4.localhost), port)); static hosts +module Make_with_stack (R: Mirage_random.C) (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4) = struct + include Resolver_lwt -module type S = sig - module DNS : Dns_resolver_mirage.S - val default_ns : Ipaddr.V4.t - val vchan_resolver : tld:string -> Resolver_lwt.rewrite_fn - val dns_stub_resolver: - ?ns:Ipaddr.V4.t -> ?ns_port:int -> DNS.t -> Resolver_lwt.rewrite_fn - val register: - ?ns:Ipaddr.V4.t -> ?ns_port:int -> ?stack:DNS.stack -> - Resolver_lwt.t -> unit - val init: - ?ns:Ipaddr.V4.t -> ?ns_port:int -> ?stack:DNS.stack -> unit -> Resolver_lwt.t -end - -module Make(DNS:Dns_resolver_mirage.S) = struct - module DNS = DNS - - let vchan_resolver ~tld = - let tld_len = String.length tld in - let get_short_host uri = - let n = get_host uri in - let len = String.length n in - if len > tld_len && (String.sub n (len-tld_len) tld_len = tld) then - String.sub n 0 (len-tld_len) - else - n - in - fun service uri -> - (* Strip the tld from the hostname *) - let remote_name = get_short_host uri in - Printf.printf "vchan_lookup: %s %s -> normalizes to %s\n%!" - (Sexplib.Sexp.to_string_hum (Resolver.sexp_of_service service)) - (Uri.to_string uri) remote_name; - Lwt.return (`Vchan_domain_socket (remote_name, service.Resolver.name)) - - let default_ns = Ipaddr.V4.of_string_exn "8.8.8.8" - - let dns_stub_resolver ?(ns=default_ns) ?(ns_port=53) dns service uri - : Conduit.endp Lwt.t = - let host = get_host uri in - let port = get_port service uri in - (match Ipaddr.of_string host with - | Error _ -> DNS.gethostbyname ~server:ns ~dns_port:ns_port dns host - | Ok addr -> Lwt.return [addr]) >>= fun res -> - List.filter (function Ipaddr.V4 _ -> true | _ -> false) res - |> function - | [] -> Lwt.return (`Unknown ("name resolution failed")) - | addr::_ -> Lwt.return (`TCP (addr,port)) - - let register ?(ns=default_ns) ?(ns_port=53) ?stack res = + module R = struct + let vchan_resolver ~tld = + let tld_len = String.length tld in + let get_short_host uri = + let n = get_host uri in + let len = String.length n in + if len > tld_len && (String.sub n (len-tld_len) tld_len = tld) then + String.sub n 0 (len-tld_len) + else + n + in + fun service uri -> + (* Strip the tld from the hostname *) + let remote_name = get_short_host uri in + Printf.printf "vchan_lookup: %s %s -> normalizes to %s\n%!" + (Sexplib.Sexp.to_string_hum (Resolver.sexp_of_service service)) + (Uri.to_string uri) remote_name; + Lwt.return (`Vchan_domain_socket (remote_name, service.Resolver.name)) + + module DNS = Dns_client_mirage.Make(R)(S) + + let dns_stub_resolver dns service uri : Conduit.endp Lwt.t = + let hostn = get_host uri in + let port = get_port service uri in + (match Ipaddr.V4.of_string hostn with + | Ok addr -> Lwt.return (Ok addr) + | Error _ -> + let hostname = Domain_name.(host_exn (of_string_exn hostn)) in + DNS.gethostbyname dns hostname) >|= function + | Error (`Msg err) -> `Unknown ("name resolution failed: " ^ err) + | Ok addr -> `TCP (Ipaddr.V4 addr, port) + + let register ?ns ?(ns_port = 53) ?stack res = begin match stack with - | Some s -> - (* DNS stub resolver *) - let dns = DNS.create s in - let f = dns_stub_resolver ~ns ~ns_port dns in - Resolver_lwt.add_rewrite ~host:"" ~f res - | None -> () + | Some s -> + (* DNS stub resolver *) + let nameserver = match ns with None -> None | Some ip -> Some (`TCP, (ip, ns_port)) in + let dns = DNS.create ?nameserver s in + let f = dns_stub_resolver dns in + Resolver_lwt.add_rewrite ~host:"" ~f res + | None -> () end; let service = Resolver_lwt.(service res ++ static_service) in Resolver_lwt.set_service ~f:service res; @@ -127,13 +114,9 @@ module Make(DNS:Dns_resolver_mirage.S) = struct let vchan_res = vchan_resolver ~tld:vchan_tld in Resolver_lwt.add_rewrite ~host:vchan_tld ~f:vchan_res res - let init ?ns ?ns_port ?stack () = - let res = Resolver_lwt.init () in - register ?ns ?ns_port ?stack res; - res -end - -module Make_with_stack (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4) = struct - module R = Make(Dns_resolver_mirage.Make(T)(S)) - include Resolver_lwt + let init ?ns ?ns_port ?stack () = + let res = Resolver_lwt.init () in + register ?ns ?ns_port ?stack res; + res + end end diff --git a/mirage/resolver_mirage.mli b/mirage/resolver_mirage.mli index e7f833c3..b9cb6eab 100644 --- a/mirage/resolver_mirage.mli +++ b/mirage/resolver_mirage.mli @@ -26,40 +26,17 @@ val static : (string, (port:int -> Conduit.endp)) Hashtbl.t -> Resolver_lwt.t maps [localhost] to [127.0.0.1], and fails on all other hostnames. *) val localhost : Resolver_lwt.t -(** Module allowing to build a {!Resolver_lwt} than can perform DNS lookups. *) -module type S = sig - module DNS : Dns_resolver_mirage.S - - (** Default resolver to use, which is [8.8.8.8] (Google DNS). *) - val default_ns : Ipaddr.V4.t - - val vchan_resolver : tld:string -> Resolver_lwt.rewrite_fn - - (** [dns_stub_resolver ?ns ?dns_port dns] will return a resolver that uses - the stub resolver [ns] on port [ns_port] to resolve URIs via - the [dns] network interface. *) - val dns_stub_resolver: - ?ns:Ipaddr.V4.t -> ?ns_port:int -> DNS.t -> Resolver_lwt.rewrite_fn - - (** [register ?ns ?ns_port ?stack res] TODO *) - val register: - ?ns:Ipaddr.V4.t -> ?ns_port:int -> ?stack:DNS.stack -> - Resolver_lwt.t -> unit - - (** [init ?ns ?ns_port ?stack ()] TODO *) - val init: - ?ns:Ipaddr.V4.t -> ?ns_port:int -> ?stack:DNS.stack -> unit -> Resolver_lwt.t -end - -(** Given a DNS resolver {{:https://github.com/mirage/ocaml-dns}implementation}, - provide a {!Resolver_lwt} that can perform DNS lookups to return - endpoints. *) -module Make(DNS:Dns_resolver_mirage.S) : S with module DNS = DNS - (** Provides a DNS-enabled {!Resolver_lwt} given a network stack. See {!Make}. *) -module Make_with_stack (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4) : sig +module Make_with_stack (R: Mirage_random.C) (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4) : sig include Resolver_lwt.S with type t = Resolver_lwt.t - module R : S with type DNS.stack = S.t + + module R : sig + val register : ?ns:Ipaddr.V4.t -> ?ns_port:int -> ?stack:S.t -> Resolver_lwt.t -> unit + + (** [init ?ns ?ns_port ?stack ()] TODO *) + val init: + ?ns:Ipaddr.V4.t -> ?ns_port:int -> ?stack:S.t -> unit -> t + end end