Skip to content

Commit

Permalink
Don't use cstruct - use string/bytes!
Browse files Browse the repository at this point in the history
Co-authored-by: Hannes Mehnert <[email protected]>
Co-authored-by: Reynir Björnsson <[email protected]>
  • Loading branch information
3 people committed Apr 4, 2024
1 parent fe92b0a commit d59bfc2
Show file tree
Hide file tree
Showing 12 changed files with 591 additions and 624 deletions.
3 changes: 1 addition & 2 deletions app/dune
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
dns-client-lwt
mirage-crypto-rng-lwt
mtime.clock.os
tuntap
cstruct-lwt))
tuntap))

(executable
(name miragevpn_client_notun)
Expand Down
51 changes: 28 additions & 23 deletions app/miragevpn_client_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,20 @@ let open_tun config { Miragevpn.cidr; gateway } :
let safe_close fd =
Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit)

let rec write_to_fd fd data =
if Cstruct.length data = 0 then Lwt_result.return ()
else
Lwt.catch
(fun () ->
Lwt_cstruct.write fd data >|= Cstruct.shift data >>= write_to_fd fd)
(fun e ->
safe_close fd >>= fun () ->
Lwt_result.lift (error_msgf "TCP write error %s" (Printexc.to_string e)))
let write_to_fd fd data =
let rec write_to_fd fd data off =
if String.length data = off then Lwt_result.return ()
else
Lwt.catch
(fun () ->
Lwt_unix.write_string fd data off (String.length data - off)
>>= fun written -> write_to_fd fd data (off + written))
(fun e ->
safe_close fd >>= fun () ->
Lwt_result.lift
(error_msgf "TCP write error %s" (Printexc.to_string e)))
in
write_to_fd fd data 0

let write_multiple_to_fd fd bufs =
Lwt_list.fold_left_s
Expand All @@ -89,8 +94,9 @@ let write_multiple_to_fd fd bufs =
let write_udp fd data =
Lwt.catch
(fun () ->
let len = Cstruct.length data in
Lwt_unix.send fd (Cstruct.to_bytes data) 0 len [] >|= fun sent ->
let len = String.length data in
(* Lwt_unix.send_substring doesn't exist :( *)
Lwt_unix.send fd (Bytes.unsafe_of_string data) 0 len [] >|= fun sent ->
if sent <> len then
Logs.warn (fun m ->
m "UDP short write (length %d, written %d)" len sent);
Expand Down Expand Up @@ -122,10 +128,8 @@ let read_from_fd fd =
let buf = Bytes.create bufsize in
Lwt_unix.read fd buf 0 bufsize >>= fun count ->
if count = 0 then failwith "end of file from server"
else
let cs = Cstruct.of_bytes ~len:count buf in
Logs.debug (fun m -> m "read %d bytes" count);
Lwt.return cs)
else Logs.debug (fun m -> m "read %d bytes" count);
Lwt.return (Bytes.sub_string buf 0 count))
|> Lwt_result.map_error (fun e -> `Msg (Printexc.to_string e))

let rec reader_tcp mvar fd =
Expand All @@ -142,9 +146,8 @@ let read_udp =
fun fd ->
Lwt_result.catch (fun () ->
Lwt_unix.recvfrom fd buf 0 bufsize [] >>= fun (count, _sa) ->
let cs = Cstruct.of_bytes ~len:count buf in
Logs.debug (fun m -> m "read %d bytes" count);
Lwt.return (Some cs))
Lwt.return (Some (Bytes.sub_string buf 0 count)))
|> Lwt_result.map_error (fun e -> `Msg (Printexc.to_string e))

let rec reader_udp mvar r =
Expand Down Expand Up @@ -229,7 +232,7 @@ type conn = {
mutable peer :
[ `Udp of Lwt_unix.file_descr | `Tcp of Lwt_unix.file_descr ] option;
mutable est_switch : Lwt_switch.t;
data_mvar : Cstruct.t list Lwt_mvar.t;
data_mvar : string list Lwt_mvar.t;
est_mvar : (Miragevpn.ip_config * int, unit) result Lwt_mvar.t;
event_mvar : Miragevpn.event Lwt_mvar.t;
}
Expand Down Expand Up @@ -308,16 +311,18 @@ let send_recv conn config ip_config _mtu =
(* not using write_to_fd here because partial writes to a tun
interface are semantically different from single write()s: *)
Lwt_list.iter_p
(fun pkt -> Lwt_cstruct.write tun_fd pkt >|= ignore)
(fun pkt ->
Lwt_unix.write_string tun_fd pkt 0 (String.length pkt) >|= ignore)
pkts
>>= fun () -> process_incoming ()
in
let rec process_outgoing tun_fd =
let open Lwt_result.Infix in
let buf = Cstruct.create 1500 in
Lwt_cstruct.read tun_fd buf |> Lwt_result.ok >|= Cstruct.sub buf 0
>>= fun buf ->
match Miragevpn.outgoing conn.o_client buf with
let buf = Bytes.create 1500 in
Lwt_unix.read tun_fd buf 0 (Bytes.length buf)
|> Lwt_result.ok >|= Bytes.sub_string buf 0
>>= fun data ->
match Miragevpn.outgoing conn.o_client data with
| Error `Not_ready -> failwith "tunnel not ready, dropping data"
| Ok (s', out) ->
conn.o_client <- s';
Expand Down
126 changes: 54 additions & 72 deletions src/config.ml
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ module Conf_map = struct
key corresponds to, and its semantics if there can be any doubt.
*)
type 'a k =
| Auth : Mirage_crypto.Hash.hash k
| Auth : [< Digestif.hash' > `MD5 `SHA1 `SHA224 `SHA256 `SHA384 `SHA512 ] k
| Auth_nocache : flag k
| Auth_retry : [ `Interact | `Nointeract | `None ] k
| Auth_user_pass : (string * string) k
Expand Down Expand Up @@ -184,7 +184,7 @@ module Conf_map = struct
| Mssfix : int k
| Mute_replay_warnings : flag k
| Passtos : flag k
| Peer_fingerprint : Cstruct.t list k
| Peer_fingerprint : string list k
| Persist_key : flag k
| Persist_tun : flag k
| Ping_interval : [ `Not_configured | `Seconds of int ] k
Expand Down Expand Up @@ -225,19 +225,11 @@ module Conf_map = struct
| Route_nopull : flag k
| Script_security : int k
| Secret
: ([ `Incoming | `Outgoing ] option
* Cstruct.t
* Cstruct.t
* Cstruct.t
* Cstruct.t)
: ([ `Incoming | `Outgoing ] option * string * string * string * string)
k
| Server : Ipaddr.V4.Prefix.t k
| Tls_auth
: ([ `Incoming | `Outgoing ] option
* Cstruct.t
* Cstruct.t
* Cstruct.t
* Cstruct.t)
: ([ `Incoming | `Outgoing ] option * string * string * string * string)
k
| Tls_cert : X509.Certificate.t k
| Tls_mode : [ `Client | `Server ] k
Expand Down Expand Up @@ -299,7 +291,7 @@ module Conf_map = struct
let cert_pubkey = X509.Certificate.public_key cert in
let key_pubkey = X509.Private_key.public key in
if
Cstruct.equal
String.equal
(X509.Public_key.fingerprint cert_pubkey)
(X509.Public_key.fingerprint key_pubkey)
then Ok `Some
Expand Down Expand Up @@ -371,8 +363,8 @@ module Conf_map = struct
%a\n\
-----END OpenVPN Static key V1-----"
Fmt.(array ~sep:(any "\n") string)
(match Cstruct.concat [ a; b; c; d ] |> Hex.of_cstruct with
| `Hex h -> Array.init (256 / 16) (fun i -> String.sub h (i * 32) 32))
(let h = String.concat "" [ a; b; c; d ] |> Ohex.encode in
Array.init (256 / 16) (fun i -> String.sub h (i * 32) 32))

let pp_tls_crypt_client ppf key =
let lines = Tls_crypt.save_v1 key in
Expand Down Expand Up @@ -406,21 +398,19 @@ module Conf_map = struct
(X509.Certificate.validity cert)
X509.Host.Set.pp
(X509.Certificate.hostnames cert)
Hex.pp
(Hex.of_cstruct (X509.Certificate.fingerprint `SHA256 cert))
Hex.pp
(Hex.of_cstruct
(X509.Public_key.fingerprint ~hash:`SHA256
(X509.Certificate.public_key cert)))
(X509.Certificate.encode_pem cert |> Cstruct.to_string)
Ohex.pp
(X509.Certificate.fingerprint `SHA256 cert)
Ohex.pp
(X509.Public_key.fingerprint ~hash:`SHA256
(X509.Certificate.public_key cert))
(X509.Certificate.encode_pem cert)
in
let pp_cert cert = p () "cert [inline]\n<cert>\n%a</cert>" pp_x509 cert in
let pp_ca certs =
p () "ca [inline]\n<ca>\n%a</ca>" Fmt.(list ~sep:(any "\n") pp_x509) certs
in
let pp_x509_private_key key =
p () "key [inline]\n<key>\n%s</key>"
(X509.Private_key.encode_pem key |> Cstruct.to_string)
p () "key [inline]\n<key>\n%s</key>" (X509.Private_key.encode_pem key)
in
let pp_tls_version ppf v =
Fmt.string ppf
Expand All @@ -438,14 +428,15 @@ module Conf_map = struct
| `SHA224 -> "SHA224"
| `SHA256 -> "SHA256"
| `SHA384 -> "SHA384"
| `SHA512 -> "SHA512")
| `SHA512 -> "SHA512"
| _ -> assert false (* FIXME *))
in
let pp_cipher ppf v = Fmt.string ppf (cipher_to_string (v :> cipher)) in
let pp_fingerprint ppf fp =
for i = 0 to Cstruct.length fp - 1 do
let a, b = Hex.of_char (Cstruct.get fp i) in
Fmt.pf ppf "%c%c" a b;
if i < Cstruct.length fp - 1 then Fmt.pf ppf ":"
for i = 0 to String.length fp - 1 do
let a = String.make 1 (String.get fp i) in
Fmt.pf ppf "%a" Ohex.pp a;
if i < String.length fp - 1 then Fmt.pf ppf ":"
done
in
match (k, v) with
Expand Down Expand Up @@ -518,7 +509,7 @@ module Conf_map = struct
| Persist_tun, () -> p () "persist-tun"
| Pkcs12, p12 ->
p () "pkcs12 [inline]\n<pkcs12>\n%s\n</pkcs12>"
(Base64.encode_exn (Cstruct.to_string (X509.PKCS12.encode_der p12)))
(Base64.encode_exn (X509.PKCS12.encode_der p12))
| Port, port -> p () "port %u" port
| Proto, (ip_v, kind) ->
p () "proto %s%s%s"
Expand Down Expand Up @@ -878,26 +869,26 @@ let a_option_with_single_path name kind =
let a_ca = a_option_with_single_path "ca" `Ca

let a_ca_payload str =
match X509.Certificate.decode_pem_multiple (Cstruct.of_string str) with
match X509.Certificate.decode_pem_multiple str with
| Ok certs -> Ok (B (Ca, certs))
| Error (`Msg msg) -> Error (Fmt.str "ca: invalid certificate(s): %s" msg)

let a_cert = a_option_with_single_path "cert" `Tls_cert

let a_cert_payload str =
match X509.Certificate.decode_pem (Cstruct.of_string str) with
match X509.Certificate.decode_pem str with
| Ok cert -> Ok (B (Tls_cert, cert))
| Error (`Msg msg) -> Error (Fmt.str "cert: invalid certificate: %s" msg)

let a_key = a_option_with_single_path "key" `Tls_key

let a_key_payload str =
match X509.Private_key.decode_pem (Cstruct.of_string str) with
match X509.Private_key.decode_pem str with
| Ok key -> Ok (B (Tls_key, key))
| Error (`Msg msg) -> Error ("no key found in x509 tls-key: " ^ msg)

let a_pkcs12_payload str =
match X509.PKCS12.decode_der (Cstruct.of_string str) with
match X509.PKCS12.decode_der str with
| Error (`Msg msg) -> Error ("failed to decode PKCS12: " ^ msg)
| Ok p12 -> Ok (B (Pkcs12, p12))

Expand All @@ -910,22 +901,13 @@ let a_fingerprint =
| ('0' .. '9' | 'a' .. 'f' | 'A' .. 'F') as d -> return d
| d -> Fmt.kstr fail "invalid hex character %C" d
in
let hex_val = function
| '0' .. '9' as c -> Char.code c - Char.code '0'
| 'a' .. 'f' as c -> 10 + Char.code c - Char.code 'a'
| 'A' .. 'F' as c -> 10 + Char.code c - Char.code 'A'
| _ -> assert false
in
let byte =
hex_digit >>= fun a ->
hex_digit >>| fun b -> (hex_val a lsl 4) + hex_val b
hex_digit >>| fun b ->
String.init 2 (function 0 -> a | 1 -> b | _ -> assert false)
in
count 31 (byte <* char ':') >>= fun hd ->
byte >>| fun tl ->
let buf = Cstruct.create 32 in
List.iteri (Cstruct.set_uint8 buf) hd;
Cstruct.set_uint8 buf 31 tl;
buf
byte >>| fun tl -> Ohex.decode (String.concat "" (hd @ [ tl ]))

let a_multi_fingerprint =
skip_many (a_whitespace_or_comment *> end_of_line)
Expand Down Expand Up @@ -999,32 +981,32 @@ let a_base64_line =

let inline_payload element =
let abort s = fail ("Invalid " ^ element ^ " HMAC key: " ^ s) in
Angstrom.skip_many (a_whitespace_or_comment *> end_of_line)
*> (string "-----BEGIN OpenVPN Static key V1-----" *> a_newline
<|> abort "Missing Static key V1 -----BEGIN mark")
*> many_till
( take_while (function
| 'a' .. 'f' | 'A' .. 'F' | '0' .. '9' -> true
| _ -> false)
<* (end_of_line <|> abort "Invalid hex character")
>>= fun hex ->
try return (Cstruct.of_hex hex)
with Invalid_argument msg -> abort msg )
(string "-----END OpenVPN Static key V1-----" *> a_newline
<|> abort "Missing END mark")
( Angstrom.skip_many (a_whitespace_or_comment *> end_of_line)
*> (string "-----BEGIN OpenVPN Static key V1-----" *> a_newline
<|> abort "Missing Static key V1 -----BEGIN mark")
*> many_till
( take_while (function
| 'a' .. 'f' | 'A' .. 'F' | '0' .. '9' -> true
| _ -> false)
<* (end_of_line <|> abort "Invalid hex character")
>>= fun hex ->
try return (Ohex.decode hex) with Invalid_argument msg -> abort msg
)
(string "-----END OpenVPN Static key V1-----" *> a_newline
<|> abort "Missing END mark")
<* commit
<* (skip_many (a_newline <|> a_whitespace) *> end_of_input
<|> ( pos >>= fun i ->
abort (Fmt.str "Data after -----END mark at byte offset %d" i) ))
>>= (fun lst ->
let sz = Cstruct.lenv lst in
if 256 = sz then return lst
else
abort @@ "Wrong size (" ^ string_of_int sz
^ "); need exactly 256 bytes")
>>| Cstruct.concat
>>| fun cs ->
Cstruct.(sub cs 0 64, sub cs 64 64, sub cs 128 64, sub cs (128 + 64) 64)
>>= fun lst ->
let data = String.concat "" lst in
let sz = String.length data in
if 256 = sz then return data
else
abort @@ "Wrong size (" ^ string_of_int sz ^ "); need exactly 256 bytes"
)
>>| fun data ->
String.(sub data 0 64, sub data 64 64, sub data 128 64, sub data (128 + 64) 64)

let a_tls_crypt_v2_client_payload force_cookie =
let line = a_line not_control_char in
Expand Down Expand Up @@ -1768,11 +1750,11 @@ let eq : eq =
(fun (type x) (k : x k) (v : x) (v2 : x) ->
match (k, v, v2) with
| Secret, (dir, a, b, c, d), (dir', a', b', c', d') ->
dir = dir' && Cstruct.equal a a' && Cstruct.equal b b'
&& Cstruct.equal c c' && Cstruct.equal d d'
dir = dir' && String.equal a a' && String.equal b b'
&& String.equal c c' && String.equal d d'
| Tls_auth, (dir, a, b, c, d), (dir', a', b', c', d') ->
dir = dir' && Cstruct.equal a a' && Cstruct.equal b b'
&& Cstruct.equal c c' && Cstruct.equal d d'
dir = dir' && String.equal a a' && String.equal b b'
&& String.equal c c' && String.equal d d'
| Remote, remotes_lst, remotes_lst2 ->
List.for_all2
(fun (a, port1, proto1) (b, port2, proto2) ->
Expand Down
33 changes: 0 additions & 33 deletions src/cstruct_ext.ml

This file was deleted.

Loading

0 comments on commit d59bfc2

Please sign in to comment.