Skip to content

Commit

Permalink
use udns
Browse files Browse the repository at this point in the history
  • Loading branch information
hannesm committed Aug 14, 2019
1 parent 05ddf13 commit 61ed231
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 104 deletions.
2 changes: 1 addition & 1 deletion mirage-conduit.opam
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ depends: [
"mirage-stack-lwt" {>= "1.3.0"}
"mirage-time-lwt" {>= "1.1.0"}
"mirage-flow-lwt" {>= "1.2.0"}
"mirage-dns" {>= "3.0.0"}
"dns-mirage-client"
"conduit-lwt"
"vchan" {>= "3.0.0"}
"xenstore"
Expand Down
7 changes: 3 additions & 4 deletions mirage/conduit_mirage.ml
Original file line number Diff line number Diff line change
Expand Up @@ -327,19 +327,18 @@ let rec server (e:Conduit.endp): server Lwt.t = match e with
| `TLS (x, y) -> server y >>= fun s -> tls_server x s
| `Unknown s -> err_unknown s

module Context (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4) = struct
module Context (R: Mirage_random.C) (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4) = struct

type t = Resolver_lwt.t * conduit

module DNS = Dns_resolver_mirage.Make(T)(S)
module RES = Resolver_mirage.Make(DNS)
module RES = Resolver_mirage.Make_with_stack(R)(T)(S)

let conduit = empty
let stackv4 = stackv4 (module S: Mirage_stack_lwt.V4 with type t = S.t)

let create ?(tls=false) stack =
let res = Resolver_lwt.init () in
RES.register ~stack res;
RES.R.register ~stack res;
with_tcp conduit stackv4 stack >>= fun conduit ->
if tls then
with_tls conduit >|= fun conduit ->
Expand Down
2 changes: 1 addition & 1 deletion mirage/conduit_mirage.mli
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ end
include S

(** {2 Context for MirageOS conduit resolvers} *)
module Context (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4): sig
module Context (R: Mirage_random.C) (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4): sig

type t = Resolver_lwt.t * conduit
(** The type for contexts of conduit resolvers. *)
Expand Down
4 changes: 2 additions & 2 deletions mirage/dune
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
(wrapped false)
(optional)
(libraries conduit conduit-lwt mirage-stack-lwt mirage-time-lwt
mirage-flow-lwt mirage-dns ipaddr-sexp
vchan tls tls.mirage xenstore.client))
mirage-random mirage-flow-lwt dns-mirage-client ipaddr-sexp
vchan tls tls.mirage xenstore.client uri.services))
111 changes: 47 additions & 64 deletions mirage/resolver_mirage.ml
Original file line number Diff line number Diff line change
Expand Up @@ -63,77 +63,60 @@ let localhost =
(fun ~port -> `TCP (Ipaddr.(V4 V4.localhost), port));
static hosts

module Make_with_stack (R: Mirage_random.C) (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4) = struct
include Resolver_lwt

module type S = sig
module DNS : Dns_resolver_mirage.S
val default_ns : Ipaddr.V4.t
val vchan_resolver : tld:string -> Resolver_lwt.rewrite_fn
val dns_stub_resolver:
?ns:Ipaddr.V4.t -> ?ns_port:int -> DNS.t -> Resolver_lwt.rewrite_fn
val register:
?ns:Ipaddr.V4.t -> ?ns_port:int -> ?stack:DNS.stack ->
Resolver_lwt.t -> unit
val init:
?ns:Ipaddr.V4.t -> ?ns_port:int -> ?stack:DNS.stack -> unit -> Resolver_lwt.t
end

module Make(DNS:Dns_resolver_mirage.S) = struct
module DNS = DNS

let vchan_resolver ~tld =
let tld_len = String.length tld in
let get_short_host uri =
let n = get_host uri in
let len = String.length n in
if len > tld_len && (String.sub n (len-tld_len) tld_len = tld) then
String.sub n 0 (len-tld_len)
else
n
in
fun service uri ->
(* Strip the tld from the hostname *)
let remote_name = get_short_host uri in
Printf.printf "vchan_lookup: %s %s -> normalizes to %s\n%!"
(Sexplib.Sexp.to_string_hum (Resolver.sexp_of_service service))
(Uri.to_string uri) remote_name;
Lwt.return (`Vchan_domain_socket (remote_name, service.Resolver.name))

let default_ns = Ipaddr.V4.of_string_exn "8.8.8.8"

let dns_stub_resolver ?(ns=default_ns) ?(ns_port=53) dns service uri
: Conduit.endp Lwt.t =
let host = get_host uri in
let port = get_port service uri in
(match Ipaddr.of_string host with
| Error _ -> DNS.gethostbyname ~server:ns ~dns_port:ns_port dns host
| Ok addr -> Lwt.return [addr]) >>= fun res ->
List.filter (function Ipaddr.V4 _ -> true | _ -> false) res
|> function
| [] -> Lwt.return (`Unknown ("name resolution failed"))
| addr::_ -> Lwt.return (`TCP (addr,port))

let register ?(ns=default_ns) ?(ns_port=53) ?stack res =
module R = struct
let vchan_resolver ~tld =
let tld_len = String.length tld in
let get_short_host uri =
let n = get_host uri in
let len = String.length n in
if len > tld_len && (String.sub n (len-tld_len) tld_len = tld) then
String.sub n 0 (len-tld_len)
else
n
in
fun service uri ->
(* Strip the tld from the hostname *)
let remote_name = get_short_host uri in
Printf.printf "vchan_lookup: %s %s -> normalizes to %s\n%!"
(Sexplib.Sexp.to_string_hum (Resolver.sexp_of_service service))
(Uri.to_string uri) remote_name;
Lwt.return (`Vchan_domain_socket (remote_name, service.Resolver.name))

module DNS = Dns_mirage_client.Make(R)(S)

let dns_stub_resolver dns service uri : Conduit.endp Lwt.t =
let hostn = get_host uri in
let port = get_port service uri in
(match Ipaddr.V4.of_string hostn with
| Ok addr -> Lwt.return (Ok addr)
| Error _ ->
let hostname = Domain_name.(host_exn (of_string_exn hostn)) in
DNS.gethostbyname dns hostname) >|= function
| Error (`Msg err) -> `Unknown ("name resolution failed: " ^ err)
| Ok addr -> `TCP (Ipaddr.V4 addr, port)

let register ?ns ?(ns_port = 53) ?stack res =
begin match stack with
| Some s ->
(* DNS stub resolver *)
let dns = DNS.create s in
let f = dns_stub_resolver ~ns ~ns_port dns in
Resolver_lwt.add_rewrite ~host:"" ~f res
| None -> ()
| Some s ->
(* DNS stub resolver *)
let nameserver = match ns with None -> None | Some ip -> Some (`TCP, (ip, ns_port)) in
let dns = DNS.create ?nameserver s in
let f = dns_stub_resolver dns in
Resolver_lwt.add_rewrite ~host:"" ~f res
| None -> ()
end;
let service = Resolver_lwt.(service res ++ static_service) in
Resolver_lwt.set_service ~f:service res;
let vchan_tld = ".xen" in
let vchan_res = vchan_resolver ~tld:vchan_tld in
Resolver_lwt.add_rewrite ~host:vchan_tld ~f:vchan_res res

let init ?ns ?ns_port ?stack () =
let res = Resolver_lwt.init () in
register ?ns ?ns_port ?stack res;
res
end

module Make_with_stack (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4) = struct
module R = Make(Dns_resolver_mirage.Make(T)(S))
include Resolver_lwt
let init ?ns ?ns_port ?stack () =
let res = Resolver_lwt.init () in
register ?ns ?ns_port ?stack res;
res
end
end
41 changes: 9 additions & 32 deletions mirage/resolver_mirage.mli
Original file line number Diff line number Diff line change
Expand Up @@ -26,40 +26,17 @@ val static : (string, (port:int -> Conduit.endp)) Hashtbl.t -> Resolver_lwt.t
maps [localhost] to [127.0.0.1], and fails on all other hostnames. *)
val localhost : Resolver_lwt.t

(** Module allowing to build a {!Resolver_lwt} than can perform DNS lookups. *)
module type S = sig
module DNS : Dns_resolver_mirage.S

(** Default resolver to use, which is [8.8.8.8] (Google DNS). *)
val default_ns : Ipaddr.V4.t

val vchan_resolver : tld:string -> Resolver_lwt.rewrite_fn

(** [dns_stub_resolver ?ns ?dns_port dns] will return a resolver that uses
the stub resolver [ns] on port [ns_port] to resolve URIs via
the [dns] network interface. *)
val dns_stub_resolver:
?ns:Ipaddr.V4.t -> ?ns_port:int -> DNS.t -> Resolver_lwt.rewrite_fn

(** [register ?ns ?ns_port ?stack res] TODO *)
val register:
?ns:Ipaddr.V4.t -> ?ns_port:int -> ?stack:DNS.stack ->
Resolver_lwt.t -> unit

(** [init ?ns ?ns_port ?stack ()] TODO *)
val init:
?ns:Ipaddr.V4.t -> ?ns_port:int -> ?stack:DNS.stack -> unit -> Resolver_lwt.t
end

(** Given a DNS resolver {{:https://github.com/mirage/ocaml-dns}implementation},
provide a {!Resolver_lwt} that can perform DNS lookups to return
endpoints. *)
module Make(DNS:Dns_resolver_mirage.S) : S with module DNS = DNS

(** Provides a DNS-enabled {!Resolver_lwt} given a network stack.
See {!Make}.
*)
module Make_with_stack (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4) : sig
module Make_with_stack (R: Mirage_random.C) (T: Mirage_time_lwt.S) (S: Mirage_stack_lwt.V4) : sig
include Resolver_lwt.S with type t = Resolver_lwt.t
module R : S with type DNS.stack = S.t

module R : sig
val register : ?ns:Ipaddr.V4.t -> ?ns_port:int -> ?stack:S.t -> Resolver_lwt.t -> unit

(** [init ?ns ?ns_port ?stack ()] TODO *)
val init:
?ns:Ipaddr.V4.t -> ?ns_port:int -> ?stack:S.t -> unit -> t
end
end

0 comments on commit 61ed231

Please sign in to comment.