Skip to content

Commit

Permalink
use Logs for server and test_server, avoid Printf.printf
Browse files Browse the repository at this point in the history
  • Loading branch information
hannesm committed Jun 19, 2023
1 parent 92cf212 commit b0cc9ec
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 44 deletions.
27 changes: 14 additions & 13 deletions lib/server.ml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

open Util

let src = Logs.Src.create "awa.server" ~doc:"AWA server"
module Log = (val Logs.src_log src : Logs.LOG)

type event =
| Channel_exec of (int32 * string)
| Channel_subsystem of (int32 * string)
Expand Down Expand Up @@ -186,14 +189,14 @@ let rec input_userauth_request t username service auth_method =
| Ok pubkey when Hostkey.comptible_alg pubkey pkalg ->
try_probe t pubkey
| Ok _ ->
Logs.debug (fun m -> m "Client offered unsupported or incompatible signature algorithm %s"
pkalg);
Log.debug (fun m -> m "Client offered unsupported or incompatible signature algorithm %s"
pkalg);
failure t
| Error `Unsupported keytype ->
Logs.debug (fun m -> m "Client offered unsupported key type %s" keytype);
Log.debug (fun m -> m "Client offered unsupported key type %s" keytype);
failure t
| Error `Msg s ->
Logs.warn (fun m -> m "Failed to decode public key (while client offered a key): %s" s);
Log.warn (fun m -> m "Failed to decode public key (while client offered a key): %s" s);
disconnect t DISCONNECT_PROTOCOL_ERROR "public key decoding failed"
end
| Pubkey (pkalg, pubkey_raw, Some (sig_alg, signed)) -> (* Public key authentication *)
Expand All @@ -210,17 +213,17 @@ let rec input_userauth_request t username service auth_method =
try_auth t (by_pubkey username sig_alg pubkey session_id service signed t.user_db)
| Ok pubkey ->
if Hostkey.comptible_alg pubkey pkalg then
Logs.debug (fun m -> m "Client offered unsupported or incompatible signature algorithm %s"
pkalg)
Log.debug (fun m -> m "Client offered unsupported or incompatible signature algorithm %s"
pkalg)
else
Logs.debug (fun m -> m "Client offered signature using algorithm different from advertised: %s vs %s"
sig_alg pkalg);
Log.debug (fun m -> m "Client offered signature using algorithm different from advertised: %s vs %s"
sig_alg pkalg);
failure t
| Error `Unsupported keytype ->
Logs.debug (fun m -> m "Client attempted authentication with unsupported key type %s" keytype);
Log.debug (fun m -> m "Client attempted authentication with unsupported key type %s" keytype);
failure t
| Error `Msg s ->
Logs.warn (fun m -> m "Failed to decode public key (while authenticating): %s" s);
Log.warn (fun m -> m "Failed to decode public key (while authenticating): %s" s);
disconnect t DISCONNECT_PROTOCOL_ERROR "public key decoding failed"
end
| Password (password, None) -> (* Password authentication *)
Expand Down Expand Up @@ -333,6 +336,7 @@ let input_msg t msg now =
match msg with
| Msg_kexinit kex ->
let* neg = Kex.negotiate ~s:t.server_kexinit ~c:kex in
Logs.debug (fun m -> m "neg is %a" Kex.pp_negotiation neg);
let ignore_next_packet =
kex.first_kex_packet_follows &&
not (Kex.guessed_right ~s:t.server_kexinit ~c:kex)
Expand Down Expand Up @@ -367,9 +371,6 @@ let input_msg t msg now =
~e ~f ~k
in
let signature = Hostkey.sign neg.server_host_key_alg t.host_key h in
Format.printf "shared is %a signature is %a (hash %a)\n%!"
Cstruct.hexdump_pp (Mirage_crypto_pk.Z_extra.to_cstruct_be f)
Cstruct.hexdump_pp signature Cstruct.hexdump_pp h;
let session_id = match t.session_id with None -> h | Some x -> x in
let* new_keys_ctos, new_keys_stoc, key_eol =
Kex.Dh.derive_keys k h session_id neg now
Expand Down
84 changes: 54 additions & 30 deletions test/awa_test_server.ml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ module Driver = struct

let send_msg t msg =
let* server, msg_buf = Server.output_msg t.server msg in
Printf.printf ">>> %s\n%!" (Fmt.to_to_string Ssh.pp_message msg);
Logs.debug (fun m -> m ">>> %s" (Fmt.to_to_string Ssh.pp_message msg));
t.write_cb msg_buf;
Ok { t with server }

Expand All @@ -60,18 +60,18 @@ module Driver = struct
| Some (server, kexinit) -> send_msg { t with server } kexinit

let rec poll t =
Printf.printf "poll called, input buffer %d\n%!"
(Cstruct.length t.input_buffer);
Logs.info (fun m -> m "poll called, input buffer %d"
(Cstruct.length t.input_buffer));
let now = t.time_cb () in
let server = t.server in
let* server, msg, input_buffer = Server.pop_msg2 server t.input_buffer in
match msg with
| None ->
Printf.printf "no msg :/, input %d\n%!" (Cstruct.length input_buffer);
Logs.info (fun m -> m "no msg :/, input %d" (Cstruct.length input_buffer));
let input_buffer = cs_join input_buffer (t.read_cb ()) in
poll { t with server; input_buffer }
| Some msg ->
Printf.printf "<<< %s\n%!" (Fmt.to_to_string Ssh.pp_message msg);
Logs.debug (fun m -> m "<<< %a" Ssh.pp_message msg);
let* server, replies, event = Server.input_msg server msg now in
let t = { t with server; input_buffer } in
let* t = send_msgs t replies in
Expand All @@ -90,9 +90,6 @@ end

let ( let* ) = Result.bind

let printf = Printf.printf
let sprintf = Printf.sprintf

(* Driver callbacks *)
let read_cstruct fd () =
let len = Ssh.max_pkt_len in
Expand All @@ -103,7 +100,7 @@ let read_cstruct fd () =
else
let cbuf = Cstruct.create n in
Cstruct.blit_from_bytes buf 0 cbuf 0 n;
Format.printf "read %d bytes\n%!" (Cstruct.length cbuf);
Logs.debug (fun m -> m "read %u bytes" (Cstruct.length cbuf));
cbuf

let write_cstruct fd buf =
Expand All @@ -128,22 +125,26 @@ let bc t id data =
let op = List.nth args 1 in
let b = int_of_string (List.nth args 2) in
match op with
| "+" -> sprintf "%d\n" (a + b)
| "-" -> sprintf "%d\n" (a - b)
| "*" -> sprintf "%d\n" (a * b)
| "/" -> if b = 0 then "Don't be an ass !\n" else sprintf "%d\n" (a / b)
| op -> sprintf "Unknown operator %s\n" op
| "+" -> Printf.sprintf "%d\n" (a + b)
| "-" -> Printf.sprintf "%d\n" (a - b)
| "*" -> Printf.sprintf "%d\n" (a * b)
| "/" -> if b = 0 then "Don't be an ass !\n" else Printf.sprintf "%d\n" (a / b)
| op -> Printf.sprintf "Unknown operator %s\n" op
in
Driver.send_channel_data t id (Cstruct.of_string reply)

let rec serve t cmd =
let open Server in
let* t, poll_result = Driver.poll t in
match poll_result with
| Disconnected s -> Ok (printf "Disconnected: %s\n%!" s)
| Channel_eof id -> Ok (printf "Channel %ld EOF\n%!" id)
| Disconnected s ->
Logs.info (fun m -> m "Disconnected: %s" s);
Ok ()
| Channel_eof id ->
Logs.info (fun m -> m "Channel %lu EOF" id);
Ok ()
| Channel_data (id, data) ->
printf "channel data %d\n%!" (Cstruct.length data);
Logs.info (fun m -> m "channel data %d" (Cstruct.length data));
(match cmd with
| None -> serve t cmd
| Some "echo" ->
Expand All @@ -159,20 +160,21 @@ let rec serve t cmd =
| _ -> Error "Unexpected cmd")
| Channel_subsystem (id, exec) (* same as exec *)
| Channel_exec (id, exec) ->
printf "channel exec %s\n%!" exec;
Logs.info (fun m -> m "channel exec %s" exec);
begin match exec with
| "suicide" ->
let* _ = Driver.disconnect t in
Ok ()
| "ping" ->
let* t = Driver.send_channel_data t id (Cstruct.of_string "pong\n") in
let* _ = Driver.disconnect t in
Ok (printf "sent pong\n%!")
Logs.info (fun m -> m "sent pong");
Ok ()
| "echo" | "bc" as c -> serve t (Some c)
| _ ->
let m = sprintf "Unknown command %s\n%!" exec in
let* t = Driver.send_channel_data t id (Cstruct.of_string m) in
printf "%s\n%!" m;
let msg = Printf.sprintf "Unknown command %s" exec in
let* t = Driver.send_channel_data t id (Cstruct.of_string msg) in
Logs.info (fun m -> m "%s" msg);
let* t = Driver.disconnect t in
serve t cmd end
| _ -> failwith "Invalid SSH event"
Expand All @@ -189,9 +191,9 @@ let user_db =
[ foo; awa ]

let rec wait_connection priv_key listen_fd server_port =
printf "Awa server waiting connections on port %d\n%!" server_port;
Logs.info (fun m -> m "Awa server waiting connections on port %d" server_port);
let client_fd, _ = Unix.(accept listen_fd) in
printf "Client connected !\n%!";
Logs.info (fun m -> m "Client connected!");
let server, msgs = Server.make priv_key user_db in
let* t =
Driver.of_server server msgs
Expand All @@ -200,13 +202,13 @@ let rec wait_connection priv_key listen_fd server_port =
Mtime_clock.now
in
let () = match serve t None with
| Ok _ -> printf "Client finished\n%!"
| Error e -> printf "error: %s\n%!" e
| Ok () -> Logs.info (fun m -> m "Client finished")
| Error e -> Logs.warn (fun m -> m "error: %s" e)
in
Unix.close client_fd;
wait_connection priv_key listen_fd server_port

let () =
let jump () =
Mirage_crypto_rng_unix.initialize (module Mirage_crypto_rng.Fortuna);
let g = Mirage_crypto_rng.(create ~seed:(Cstruct.of_string "180586") (module Fortuna)) in
let (ec_priv,_) = Mirage_crypto_ec.Ed25519.generate ~g () in
Expand All @@ -216,6 +218,28 @@ let () =
Unix.(setsockopt listen_fd SO_REUSEADDR true);
Unix.(bind listen_fd (ADDR_INET (inet_addr_any, server_port)));
Unix.listen listen_fd 1;
match wait_connection priv_key listen_fd server_port with
| Error e -> printf "error %s\n%!" e
| Ok _ -> printf "ok\n%!\n"
Result.map_error
(fun msg -> `Msg msg)
(wait_connection priv_key listen_fd server_port)

let setup_log style_renderer level =
Fmt_tty.setup_std_outputs ?style_renderer ();
Logs.set_level level;
Logs.set_reporter (Logs_fmt.reporter ~dst:Format.std_formatter ())

open Cmdliner

let setup_log =
Term.(const setup_log
$ Fmt_cli.style_renderer ()
$ Logs_cli.level ())

let cmd =
let term =
Term.(term_result (const jump $ setup_log))
and info =
Cmd.info "awa_test_server" ~version:"%%VERSION_NUM"
in
Cmd.v info term

let () = exit (Cmd.eval cmd)
3 changes: 2 additions & 1 deletion test/dune
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
(public_name awa_test_server)
(modules awa_test_server)
(package awa)
(libraries awa mtime.clock.os cstruct-unix mirage-crypto-rng.unix))
(libraries awa mtime.clock.os cstruct-unix mirage-crypto-rng.unix
fmt.tty logs.fmt logs.cli fmt.cli))

(executable
(name awa_test_client)
Expand Down

0 comments on commit b0cc9ec

Please sign in to comment.