From e6a191a9b92dedfbe8ff62fd4bc65424ee4d87a8 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Thu, 21 Dec 2023 13:26:28 +0100 Subject: [PATCH 01/12] mirage: avoid an assert false, properly return an error --- mirage/awa_mirage.ml | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/mirage/awa_mirage.ml b/mirage/awa_mirage.ml index b0bad82..233fc0e 100644 --- a/mirage/awa_mirage.ml +++ b/mirage/awa_mirage.ml @@ -146,12 +146,15 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = } in writev_flow t msgs >>= fun () -> drain_handshake t >>= fun id -> - (* TODO that's a bit hardcoded... *) - let ssh = match t.state with `Active t -> t | _ -> assert false in - (match Awa.Client.outgoing_request ssh ~id req with - | Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg)) - | Ok (ssh', data) -> t.state <- `Active ssh' ; write_flow t data) >|= fun () -> - t + match t.state with + | `Active ssh -> + (match Awa.Client.outgoing_request ssh ~id req with + | Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg)) + | Ok (ssh', data) -> t.state <- `Active ssh' ; write_flow t data) >|= fun () -> + t + | `Eof -> Lwt.return (Error (`Msg "end of file")) + | `Error e -> Lwt.return (Error e) + (* copy from awa_lwt.ml and unix references removed in favor to FLOW *) type nexus_msg = From dcc1ea031f763388b25c877ac502223ac6326171 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Thu, 21 Dec 2023 16:01:12 +0100 Subject: [PATCH 02/12] provide close and shutdown in Awa_mirage --- awa-mirage.opam | 2 +- lib/client.ml | 20 ++++++++++++++++++++ lib/client.mli | 4 ++++ mirage/awa_mirage.ml | 39 ++++++++++++++++++++++++++++++++++++--- 4 files changed, 61 insertions(+), 4 deletions(-) diff --git a/awa-mirage.opam b/awa-mirage.opam index 772db0a..eeecb67 100644 --- a/awa-mirage.opam +++ b/awa-mirage.opam @@ -22,7 +22,7 @@ depends: [ "lwt" {>= "5.3.0"} "mirage-time" {>= "2.0.0"} "duration" {>= "0.2.0"} - "mirage-flow" {>= "2.0.0"} + "mirage-flow" {>= "4.0.0"} "mirage-clock" {>= "3.0.0"} "logs" ] diff --git a/lib/client.ml b/lib/client.ml index ab2f44c..f8f3ce4 100644 --- a/lib/client.ml +++ b/lib/client.ml @@ -523,3 +523,23 @@ let outgoing_data t ?(id = 0l) data = let* c, frags = Channel.output_data c data in let t' = { t with channels = Channel.update c t.channels } in Ok (output_msgs t' frags) + +let eof ?(id = 0l) t = + match + let* () = guard (established t) "not yet established" in + let* c = guard_some (Channel.lookup id t.channels) "no such channel" in + let msg = Ssh.Msg_channel_eof c.them.id in + Ok (output_msg t msg) + with + | Error _ -> t, None + | Ok (t, msg) -> t, Some msg + +let close ?(id = 0l) t = + match + let* () = guard (established t) "not yet established" in + let* c = guard_some (Channel.lookup id t.channels) "no such channel" in + let msg = Ssh.Msg_channel_close c.them.id in + Ok (output_msg t msg) + with + | Error _ -> t, None + | Ok (t, msg) -> t, Some msg diff --git a/lib/client.mli b/lib/client.mli index 5e9ee5d..3cf73ef 100644 --- a/lib/client.mli +++ b/lib/client.mli @@ -38,3 +38,7 @@ val outgoing_request : t -> ?id:int32 -> ?want_reply:bool -> val outgoing_data : t -> ?id:int32 -> Cstruct.t -> (t * Cstruct.t list, string) result + +val eof : ?id:int32 -> t -> t * Cstruct.t option + +val close : ?id:int32 -> t -> t * Cstruct.t option diff --git a/mirage/awa_mirage.ml b/mirage/awa_mirage.ml index 233fc0e..124dfc8 100644 --- a/mirage/awa_mirage.ml +++ b/mirage/awa_mirage.ml @@ -111,9 +111,42 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = | `Active _ -> assert false let close t = - (* TODO ssh session teardown (send some protocol messages) *) - FLOW.close t.flow >|= fun () -> - t.state <- `Eof + match t.state with + | `Active ssh -> + let state, msg = Awa.Client.close ssh in + t.state <- `Active state; + (match msg with + | None -> Lwt.return (Ok ()) + | Some msg -> writev_flow t [ msg ]) >>= fun _ -> + FLOW.close t.flow >|= fun () -> + t.state <- `Eof + | _ -> Lwt.return_unit + + let shutdown t mode = + match t.state with + | `Active ssh -> + begin + let state, msg = + if mode = `read_write then + Awa.Client.close ssh + else if mode = `write then + Awa.Client.eof ssh + else + ssh, None + in + t.state <- `Active state; + (match msg with + | None -> Lwt.return (Ok ()) + | Some msg -> writev_flow t [ msg ]) >>= fun _ -> + (if mode = `read_write then + FLOW.close t.flow + else + Lwt.return_unit) >|= fun () -> + match mode with + | `read | `read_write -> t.state <- `Eof + | `write -> () + end + | _ -> Lwt.return_unit let writev t bufs = let open Lwt_result.Infix in From 15b940c68c067eb12455a7bbf6b8567c76798a5e Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Thu, 21 Dec 2023 16:08:12 +0100 Subject: [PATCH 03/12] simplify - a shutdown \`read_write is a close --- mirage/awa_mirage.ml | 45 +++++++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/mirage/awa_mirage.ml b/mirage/awa_mirage.ml index 124dfc8..c0ba180 100644 --- a/mirage/awa_mirage.ml +++ b/mirage/awa_mirage.ml @@ -123,30 +123,27 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = | _ -> Lwt.return_unit let shutdown t mode = - match t.state with - | `Active ssh -> - begin - let state, msg = - if mode = `read_write then - Awa.Client.close ssh - else if mode = `write then - Awa.Client.eof ssh - else - ssh, None - in - t.state <- `Active state; - (match msg with - | None -> Lwt.return (Ok ()) - | Some msg -> writev_flow t [ msg ]) >>= fun _ -> - (if mode = `read_write then - FLOW.close t.flow - else - Lwt.return_unit) >|= fun () -> - match mode with - | `read | `read_write -> t.state <- `Eof - | `write -> () - end - | _ -> Lwt.return_unit + if mode = `read_write then + close t + else + match t.state with + | `Active ssh -> + begin + let state, msg = + if mode = `write then + Awa.Client.eof ssh + else + ssh, None + in + t.state <- `Active state; + (match msg with + | None -> Lwt.return (Ok ()) + | Some msg -> writev_flow t [ msg ]) >|= fun _ -> + match mode with + | `read | `read_write -> t.state <- `Eof + | `write -> () + end + | _ -> Lwt.return_unit let writev t bufs = let open Lwt_result.Infix in From 621fa2da72769db8a2933bb7272107b05b0c38b4 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Thu, 21 Dec 2023 17:23:35 +0100 Subject: [PATCH 04/12] mirage: preserve half-closed connections, and deal with them properly --- mirage/awa_mirage.ml | 119 +++++++++++++++++++++++++++---------------- 1 file changed, 76 insertions(+), 43 deletions(-) diff --git a/mirage/awa_mirage.ml b/mirage/awa_mirage.ml index c0ba180..660ea79 100644 --- a/mirage/awa_mirage.ml +++ b/mirage/awa_mirage.ml @@ -24,12 +24,24 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = type flow = { flow : FLOW.flow ; - mutable state : [ `Active of Awa.Client.t | `Eof | `Error of error ] + mutable state : [ + | `Active of Awa.Client.t + | `Read_closed of Awa.Client.t + | `Write_closed of Awa.Client.t + | `Closed + | `Error of error ] } let write_flow t buf = FLOW.write t.flow buf >>= function | Ok () -> Lwt.return (Ok ()) + | Error `Closed -> + Log.warn (fun m -> m "error closed while writing"); + (match t.state with + | `Active ssh -> t.state <- `Write_closed ssh + | `Read_closed _ -> t.state <- `Closed + | `Write_closed _ | `Closed | `Error _ -> ()); + Lwt.return (Error (`Write `Closed)) | Error w -> Log.warn (fun m -> m "error %a while writing" F.pp_write_error w); t.state <- `Error (`Write w) ; Lwt.return (Error (`Write w)) @@ -46,24 +58,37 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = let read_react t = match t.state with - | `Eof | `Error _ -> Lwt.return (Error ()) - | `Active _ -> + | `Read_closed _ | `Closed | `Error _ -> Lwt.return (Error ()) + | `Active _ | `Write_closed _ -> FLOW.read t.flow >>= function | Error e -> Log.warn (fun m -> m "error %a while reading" F.pp_error e); t.state <- `Error (`Read e); Lwt.return (Error ()) - | Ok `Eof -> t.state <- `Eof ; Lwt.return (Error ()) + | Ok `Eof -> + t.state <- + (match t.state with + | `Active ssh -> `Read_closed ssh + | `Write_closed _ -> `Closed + | _ -> assert false); + Lwt.return (Error ()) | Ok (`Data data) -> match t.state with - | `Active ssh -> + | `Active ssh | `Write_closed ssh -> begin match Awa.Client.incoming ssh (now ()) data with | Error msg -> Log.warn (fun m -> m "error %s while processing data" msg); t.state <- `Error (`Msg msg); Lwt.return (Error ()) | Ok (ssh', out, events) -> - let state' = if List.mem `Disconnected events then `Eof else `Active ssh' in + let state' = + match List.mem `Disconnected events, t.state with + | false, `Active _ -> `Active ssh' + | false, `Write_closed _ -> `Write_closed ssh' + | true, `Active _ -> `Read_closed ssh' + | true, `Write_closed _ -> `Closed + | _ -> assert false + in t.state <- state'; writev_flow t out >>= fun _ -> Lwt.return (Ok events) @@ -74,15 +99,15 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = read_react t >>= function | Ok es -> begin match t.state, List.filter (function `Established _ -> true | _ -> false) es with - | `Eof, _ -> Lwt.return (Error (`Msg "disconnected")) + | (`Read_closed _ | `Closed), _ -> Lwt.return (Error (`Msg "disconnected")) | `Error e, _ -> Lwt.return (Error e) - | `Active _, [ `Established id ] -> Lwt.return (Ok id) - | `Active _, _ -> drain_handshake t + | (`Active _ | `Write_closed _), [ `Established id ] -> Lwt.return (Ok id) + | (`Active _ | `Write_closed _), _ -> drain_handshake t end | Error () -> match t.state with | `Error e -> Lwt.return (Error e) - | `Eof -> Lwt.return (Error (`Msg "disconnected")) - | `Active _ -> assert false + | `Closed | `Read_closed _ -> Lwt.return (Error (`Msg "disconnected")) + | `Active _ | `Write_closed _ -> assert false let rec read t = read_react t >>= function @@ -107,62 +132,68 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = end | Error () -> match t.state with | `Error e -> Lwt.return (Error e) - | `Eof -> Lwt.return (Ok `Eof) - | `Active _ -> assert false + | `Closed | `Read_closed _ -> Lwt.return (Ok `Eof) + | `Active _ | `Write_closed _ -> assert false let close t = + FLOW.close t.flow >>= fun () -> match t.state with - | `Active ssh -> + | `Active ssh | `Read_closed ssh | `Write_closed ssh -> let state, msg = Awa.Client.close ssh in t.state <- `Active state; (match msg with | None -> Lwt.return (Ok ()) - | Some msg -> writev_flow t [ msg ]) >>= fun _ -> - FLOW.close t.flow >|= fun () -> - t.state <- `Eof - | _ -> Lwt.return_unit + | Some msg -> writev_flow t [ msg ]) >|= fun _ -> + t.state <- `Closed + | `Error _ | `Closed -> Lwt.return_unit let shutdown t mode = - if mode = `read_write then - close t - else - match t.state with - | `Active ssh -> - begin - let state, msg = - if mode = `write then - Awa.Client.eof ssh - else - ssh, None - in - t.state <- `Active state; - (match msg with - | None -> Lwt.return (Ok ()) - | Some msg -> writev_flow t [ msg ]) >|= fun _ -> - match mode with - | `read | `read_write -> t.state <- `Eof - | `write -> () - end - | _ -> Lwt.return_unit + match t.state with + | `Active _ | `Read_closed _ | `Write_closed _ -> + let state, msg = + match t.state, mode with + | (`Active ssh | `Read_closed ssh), `write -> Awa.Client.eof ssh + | (`Active ssh | `Read_closed ssh | `Write_closed ssh), `read_write -> + Awa.Client.close ssh + | (`Active ssh | `Read_closed ssh | `Write_closed ssh), _ -> ssh, None + | _ -> assert false + in + (t.state <- match t.state, mode with + | _, `read_write -> `Closed + | `Active _, `read -> `Read_closed state + | `Active _, `write -> `Write_closed state + | `Read_closed _, `read -> `Read_closed state + | `Read_closed _, `write -> `Closed + | `Write_closed _, `read -> `Closed + | `Write_closed _, `write -> `Write_closed state + | _ -> assert false); + (match msg with + | None -> Lwt.return (Ok ()) + | Some msg -> writev_flow t [ msg ]) >|= fun _ -> + () + | `Error _ | `Closed -> Lwt.return_unit let writev t bufs = let open Lwt_result.Infix in match t.state with - | `Active ssh -> + | `Active ssh | `Read_closed ssh -> Lwt_list.fold_left_s (fun r data -> match r with | Error e -> Lwt.return (Error e) | Ok ssh -> match Awa.Client.outgoing_data ssh data with | Ok (ssh', datas) -> - t.state <- `Active ssh'; + t.state <- (match t.state with + | `Active _ -> `Active ssh' + | `Read_closed _ -> `Read_closed ssh' + | _ -> assert false); writev_flow t datas >|= fun () -> ssh' | Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg))) (Ok ssh) bufs >|= fun _ -> () - | `Eof -> Lwt.return (Error `Closed) + | `Write_closed _ | `Closed -> Lwt.return (Error `Closed) | `Error e -> Lwt.return (Error (e :> write_error)) let write t buf = writev t [buf] @@ -182,7 +213,9 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = | Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg)) | Ok (ssh', data) -> t.state <- `Active ssh' ; write_flow t data) >|= fun () -> t - | `Eof -> Lwt.return (Error (`Msg "end of file")) + | `Read_closed _ -> Lwt.return (Error (`Msg "read closed")) + | `Write_closed _ -> Lwt.return (Error (`Msg "write closed")) + | `Closed -> Lwt.return (Error (`Msg "closed")) | `Error e -> Lwt.return (Error e) From 19ff78e09da6292661c4a3447221669430d51929 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Thu, 21 Dec 2023 17:38:47 +0100 Subject: [PATCH 05/12] mirage: avoid assertions --- mirage/awa_mirage.ml | 73 +++++++++++++++++++------------------------- 1 file changed, 31 insertions(+), 42 deletions(-) diff --git a/mirage/awa_mirage.ml b/mirage/awa_mirage.ml index 660ea79..4d7f1b8 100644 --- a/mirage/awa_mirage.ml +++ b/mirage/awa_mirage.ml @@ -32,19 +32,34 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = | `Error of error ] } + let half_close state mode = + match state, mode with + | `Active ssh, `read -> `Read_closed ssh + | `Active ssh, `write -> `Write_closed ssh + | `Active _, `read_write -> `Closed + | `Read_closed ssh, `read -> `Read_closed ssh + | `Read_closed _, (`write | `read_write) -> `Closed + | `Write_closed ssh, `write -> `Write_closed ssh + | `Write_closed _, (`read | `read_write) -> `Closed + | (`Closed | `Error _) as e, (`read | `write | `read_write) -> e + + let inject_state ssh = function + | `Active _ -> `Active ssh + | `Read_closed _ -> `Read_closed ssh + | `Write_closed _ -> `Write_closed ssh + | (`Closed | `Error _) as e -> e + let write_flow t buf = FLOW.write t.flow buf >>= function | Ok () -> Lwt.return (Ok ()) | Error `Closed -> Log.warn (fun m -> m "error closed while writing"); - (match t.state with - | `Active ssh -> t.state <- `Write_closed ssh - | `Read_closed _ -> t.state <- `Closed - | `Write_closed _ | `Closed | `Error _ -> ()); + t.state <- half_close t.state `write; Lwt.return (Error (`Write `Closed)) | Error w -> Log.warn (fun m -> m "error %a while writing" F.pp_write_error w); - t.state <- `Error (`Write w) ; Lwt.return (Error (`Write w)) + t.state <- `Error (`Write w); + Lwt.return (Error (`Write w)) let writev_flow t bufs = Lwt_list.fold_left_s (fun r d -> @@ -66,11 +81,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = t.state <- `Error (`Read e); Lwt.return (Error ()) | Ok `Eof -> - t.state <- - (match t.state with - | `Active ssh -> `Read_closed ssh - | `Write_closed _ -> `Closed - | _ -> assert false); + t.state <- half_close t.state `read; Lwt.return (Error ()) | Ok (`Data data) -> match t.state with @@ -81,15 +92,8 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = t.state <- `Error (`Msg msg); Lwt.return (Error ()) | Ok (ssh', out, events) -> - let state' = - match List.mem `Disconnected events, t.state with - | false, `Active _ -> `Active ssh' - | false, `Write_closed _ -> `Write_closed ssh' - | true, `Active _ -> `Read_closed ssh' - | true, `Write_closed _ -> `Closed - | _ -> assert false - in - t.state <- state'; + t.state <- + inject_state ssh' (if List.mem `Disconnected events then half_close t.state `read else t.state); writev_flow t out >>= fun _ -> Lwt.return (Ok events) end @@ -106,8 +110,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = end | Error () -> match t.state with | `Error e -> Lwt.return (Error e) - | `Closed | `Read_closed _ -> Lwt.return (Error (`Msg "disconnected")) - | `Active _ | `Write_closed _ -> assert false + | `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt.return (Error (`Msg "disconnected")) let rec read t = read_react t >>= function @@ -132,8 +135,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = end | Error () -> match t.state with | `Error e -> Lwt.return (Error e) - | `Closed | `Read_closed _ -> Lwt.return (Ok `Eof) - | `Active _ | `Write_closed _ -> assert false + | `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt.return (Ok `Eof) let close t = FLOW.close t.flow >>= fun () -> @@ -149,24 +151,14 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = let shutdown t mode = match t.state with - | `Active _ | `Read_closed _ | `Write_closed _ -> - let state, msg = + | `Active ssh | `Read_closed ssh | `Write_closed ssh -> + let ssh, msg = match t.state, mode with | (`Active ssh | `Read_closed ssh), `write -> Awa.Client.eof ssh - | (`Active ssh | `Read_closed ssh | `Write_closed ssh), `read_write -> - Awa.Client.close ssh - | (`Active ssh | `Read_closed ssh | `Write_closed ssh), _ -> ssh, None - | _ -> assert false + | _, `read_write -> Awa.Client.close ssh + | _ -> ssh, None in - (t.state <- match t.state, mode with - | _, `read_write -> `Closed - | `Active _, `read -> `Read_closed state - | `Active _, `write -> `Write_closed state - | `Read_closed _, `read -> `Read_closed state - | `Read_closed _, `write -> `Closed - | `Write_closed _, `read -> `Closed - | `Write_closed _, `write -> `Write_closed state - | _ -> assert false); + t.state <- inject_state ssh (half_close t.state mode); (match msg with | None -> Lwt.return (Ok ()) | Some msg -> writev_flow t [ msg ]) >|= fun _ -> @@ -183,10 +175,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = | Ok ssh -> match Awa.Client.outgoing_data ssh data with | Ok (ssh', datas) -> - t.state <- (match t.state with - | `Active _ -> `Active ssh' - | `Read_closed _ -> `Read_closed ssh' - | _ -> assert false); + t.state <- inject_state ssh' t.state; writev_flow t datas >|= fun () -> ssh' | Error msg -> From 8b62aca7bbd3b745c0d1ee35932c0168432f4639 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Thu, 21 Dec 2023 23:26:00 +0100 Subject: [PATCH 06/12] address @reynir review - and use inject_state --- mirage/awa_mirage.ml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mirage/awa_mirage.ml b/mirage/awa_mirage.ml index 4d7f1b8..060b9f5 100644 --- a/mirage/awa_mirage.ml +++ b/mirage/awa_mirage.ml @@ -141,8 +141,8 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = FLOW.close t.flow >>= fun () -> match t.state with | `Active ssh | `Read_closed ssh | `Write_closed ssh -> - let state, msg = Awa.Client.close ssh in - t.state <- `Active state; + let ssh, msg = Awa.Client.close ssh in + t.state <- inject_state ssh t.state; (match msg with | None -> Lwt.return (Ok ()) | Some msg -> writev_flow t [ msg ]) >|= fun _ -> From a4855169b33e1bc6f2a6b2041260216fc66c252f Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Fri, 22 Dec 2023 10:38:25 +0100 Subject: [PATCH 07/12] mirage: revise close and shutdown first to the ssh teardown, then do the underlying flow teardown --- mirage/awa_mirage.ml | 44 ++++++++++++++++++++------------------------ 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/mirage/awa_mirage.ml b/mirage/awa_mirage.ml index 060b9f5..10028dd 100644 --- a/mirage/awa_mirage.ml +++ b/mirage/awa_mirage.ml @@ -138,32 +138,28 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = | `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt.return (Ok `Eof) let close t = - FLOW.close t.flow >>= fun () -> - match t.state with - | `Active ssh | `Read_closed ssh | `Write_closed ssh -> - let ssh, msg = Awa.Client.close ssh in - t.state <- inject_state ssh t.state; - (match msg with - | None -> Lwt.return (Ok ()) - | Some msg -> writev_flow t [ msg ]) >|= fun _ -> - t.state <- `Closed - | `Error _ | `Closed -> Lwt.return_unit + (match t.state with + | `Active ssh | `Read_closed ssh | `Write_closed ssh -> + let ssh, msg = Awa.Client.close ssh in + t.state <- inject_state ssh t.state; + writev_flow t (Option.to_list msg) >|= fun _ -> + t.state <- `Closed + | `Error _ | `Closed -> Lwt.return_unit) >>= fun () -> + FLOW.close t.flow let shutdown t mode = - match t.state with - | `Active ssh | `Read_closed ssh | `Write_closed ssh -> - let ssh, msg = - match t.state, mode with - | (`Active ssh | `Read_closed ssh), `write -> Awa.Client.eof ssh - | _, `read_write -> Awa.Client.close ssh - | _ -> ssh, None - in - t.state <- inject_state ssh (half_close t.state mode); - (match msg with - | None -> Lwt.return (Ok ()) - | Some msg -> writev_flow t [ msg ]) >|= fun _ -> - () - | `Error _ | `Closed -> Lwt.return_unit + (match t.state with + | `Active ssh | `Read_closed ssh | `Write_closed ssh -> + let ssh, msg = + match t.state, mode with + | (`Active ssh | `Read_closed ssh), `write -> Awa.Client.eof ssh + | _, `read_write -> Awa.Client.close ssh + | _ -> ssh, None + in + t.state <- inject_state ssh (half_close t.state mode); + writev_flow t (Option.to_list msg) >|= ignore + | `Error _ | `Closed -> Lwt.return_unit) >>= fun () -> + FLOW.shutdown t.flow mode let writev t bufs = let open Lwt_result.Infix in From 557c64adada91d7bbfafb9205127a485ea51a07b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Reynir=20Bj=C3=B6rnsson?= Date: Fri, 22 Dec 2023 11:42:28 +0100 Subject: [PATCH 08/12] shutdown: don't shutdown the flow unless closed If we are in `Read_closed we may still want to read channel-close and when we are in `Write_closed we may still want to write channel-close. --- mirage/awa_mirage.ml | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/mirage/awa_mirage.ml b/mirage/awa_mirage.ml index 10028dd..736ac0b 100644 --- a/mirage/awa_mirage.ml +++ b/mirage/awa_mirage.ml @@ -148,18 +148,24 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = FLOW.close t.flow let shutdown t mode = - (match t.state with - | `Active ssh | `Read_closed ssh | `Write_closed ssh -> - let ssh, msg = - match t.state, mode with - | (`Active ssh | `Read_closed ssh), `write -> Awa.Client.eof ssh - | _, `read_write -> Awa.Client.close ssh - | _ -> ssh, None - in - t.state <- inject_state ssh (half_close t.state mode); - writev_flow t (Option.to_list msg) >|= ignore - | `Error _ | `Closed -> Lwt.return_unit) >>= fun () -> - FLOW.shutdown t.flow mode + match t.state with + | `Active ssh | `Read_closed ssh | `Write_closed ssh -> + let ssh, msg = + match t.state, mode with + | (`Active ssh | `Read_closed ssh), `write -> Awa.Client.eof ssh + | _, `read_write -> Awa.Client.close ssh + | _ -> ssh, None + in + t.state <- inject_state ssh (half_close t.state mode); + writev_flow t (Option.to_list msg) >>= fun _ -> + (* we don't [FLOW.shutdown _ mode] because we still need to read/write + channel_eof/channel_close unless both directions are closed *) + (match t.state with + | `Closed -> + (* also close on error?! *) + FLOW.close t.flow + | _ -> Lwt.return_unit) + | `Error _ | `Closed -> Lwt.return_unit let writev t bufs = let open Lwt_result.Infix in From 12040f4e92d6dcb5c69be9961dc27b2817b33de2 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Fri, 22 Dec 2023 12:04:58 +0100 Subject: [PATCH 09/12] mirage: set closed earlier in close(); also remove TODO comment --- mirage/awa_mirage.ml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mirage/awa_mirage.ml b/mirage/awa_mirage.ml index 736ac0b..6c73788 100644 --- a/mirage/awa_mirage.ml +++ b/mirage/awa_mirage.ml @@ -142,8 +142,8 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = | `Active ssh | `Read_closed ssh | `Write_closed ssh -> let ssh, msg = Awa.Client.close ssh in t.state <- inject_state ssh t.state; - writev_flow t (Option.to_list msg) >|= fun _ -> - t.state <- `Closed + t.state <- `Closed; + writev_flow t (Option.to_list msg) >|= ignore | `Error _ | `Closed -> Lwt.return_unit) >>= fun () -> FLOW.close t.flow @@ -161,9 +161,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = (* we don't [FLOW.shutdown _ mode] because we still need to read/write channel_eof/channel_close unless both directions are closed *) (match t.state with - | `Closed -> - (* also close on error?! *) - FLOW.close t.flow + | `Closed -> FLOW.close t.flow | _ -> Lwt.return_unit) | `Error _ | `Closed -> Lwt.return_unit From 611d6b4558dedcd793a1543bcc85d5710dab55b6 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Fri, 22 Dec 2023 12:26:14 +0100 Subject: [PATCH 10/12] mirage: add comment about states and why errors may occur that we ignore (thanks to @dinosaure) --- mirage/awa_mirage.ml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mirage/awa_mirage.ml b/mirage/awa_mirage.ml index 6c73788..7221bf9 100644 --- a/mirage/awa_mirage.ml +++ b/mirage/awa_mirage.ml @@ -22,6 +22,20 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = | #Mirage_flow.write_error as e -> Mirage_flow.pp_write_error ppf e | #error as e -> pp_error ppf e + (* this is the flow of a ssh-client. be aware that we're only using a single + channel. + + the state `Read_closed is set (a) when a TCP.read returned `Eof, + and (b) when the application did a shutdown `read (or `read_write). + the state `Write_closed is set (a) when a TCP.write returned `Closed, + and (b) when the application did a shutdown `write (or `read_write). + + If we're in `Write_closed, and do a shutdown `read, we'll end up in + `Closed, and attempt to (a) send a SSH_MSG_CHANNEL_CLOSE and (b) TCP.close. + This may fail, since on the TCP layer, the connection may have already be + half-closed (or fully closed) in the write direction. We ignore this error + from writev below in close. + *) type flow = { flow : FLOW.flow ; mutable state : [ @@ -143,6 +157,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = let ssh, msg = Awa.Client.close ssh in t.state <- inject_state ssh t.state; t.state <- `Closed; + (* as outlined above, this may fail since the TCP flow may already be (half-)closed *) writev_flow t (Option.to_list msg) >|= ignore | `Error _ | `Closed -> Lwt.return_unit) >>= fun () -> FLOW.close t.flow @@ -157,6 +172,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = | _ -> ssh, None in t.state <- inject_state ssh (half_close t.state mode); + (* as outlined above, this may fail since the TCP flow may already be (half-)closed *) writev_flow t (Option.to_list msg) >>= fun _ -> (* we don't [FLOW.shutdown _ mode] because we still need to read/write channel_eof/channel_close unless both directions are closed *) From e22d3ce1907c6673e8f9860f57a480d8aad507c5 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Wed, 7 Feb 2024 13:50:26 +0100 Subject: [PATCH 11/12] minor tweaks --- mirage/awa_mirage.ml | 23 +++++++++++------------ mirage/awa_mirage.mli | 6 ++---- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/mirage/awa_mirage.ml b/mirage/awa_mirage.ml index 7221bf9..7901f18 100644 --- a/mirage/awa_mirage.ml +++ b/mirage/awa_mirage.ml @@ -5,7 +5,6 @@ module Log = (val Logs.src_log src : Logs.LOG) module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = struct - module FLOW = F module MCLOCK = M type error = [ `Msg of string @@ -37,7 +36,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = from writev below in close. *) type flow = { - flow : FLOW.flow ; + flow : F.flow ; mutable state : [ | `Active of Awa.Client.t | `Read_closed of Awa.Client.t @@ -64,8 +63,8 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = | (`Closed | `Error _) as e -> e let write_flow t buf = - FLOW.write t.flow buf >>= function - | Ok () -> Lwt.return (Ok ()) + F.write t.flow buf >>= function + | Ok _ as o -> Lwt.return o | Error `Closed -> Log.warn (fun m -> m "error closed while writing"); t.state <- half_close t.state `write; @@ -78,7 +77,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = let writev_flow t bufs = Lwt_list.fold_left_s (fun r d -> match r with - | Error e -> Lwt.return (Error e) + | Error _ as e -> Lwt.return e | Ok () -> write_flow t d) (Ok ()) bufs @@ -89,7 +88,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = match t.state with | `Read_closed _ | `Closed | `Error _ -> Lwt.return (Error ()) | `Active _ | `Write_closed _ -> - FLOW.read t.flow >>= function + F.read t.flow >>= function | Error e -> Log.warn (fun m -> m "error %a while reading" F.pp_error e); t.state <- `Error (`Read e); @@ -160,7 +159,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = (* as outlined above, this may fail since the TCP flow may already be (half-)closed *) writev_flow t (Option.to_list msg) >|= ignore | `Error _ | `Closed -> Lwt.return_unit) >>= fun () -> - FLOW.close t.flow + F.close t.flow let shutdown t mode = match t.state with @@ -177,7 +176,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = (* we don't [FLOW.shutdown _ mode] because we still need to read/write channel_eof/channel_close unless both directions are closed *) (match t.state with - | `Closed -> FLOW.close t.flow + | `Closed -> F.close t.flow | _ -> Lwt.return_unit) | `Error _ | `Closed -> Lwt.return_unit @@ -266,10 +265,10 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = let send_msg flow server msg = wrapr (Awa.Server.output_msg server msg) >>= fun (server, msg_buf) -> - FLOW.write flow msg_buf >>= function + F.write flow msg_buf >>= function | Ok () -> Lwt.return server | Error w -> - Log.err (fun m -> m "error %a while writing" FLOW.pp_write_error w); + Log.err (fun m -> m "error %a while writing" F.pp_write_error w); Lwt.return server let rec send_msgs fd server = function @@ -280,9 +279,9 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = | [] -> Lwt.return server let net_read flow = - FLOW.read flow >>= function + F.read flow >>= function | Error e -> - Log.err (fun m -> m "read error %a" FLOW.pp_error e); + Log.err (fun m -> m "read error %a" F.pp_error e); Lwt.return Net_eof | Ok `Eof -> Lwt.return Net_eof diff --git a/mirage/awa_mirage.mli b/mirage/awa_mirage.mli index 7ada427..51aef1d 100644 --- a/mirage/awa_mirage.mli +++ b/mirage/awa_mirage.mli @@ -3,8 +3,6 @@ (** SSH module given a flow *) module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) : sig - module FLOW : Mirage_flow.S - (** possible errors: incoming alert, processing failure, or a problem in the underlying flow. *) type error = [ `Msg of string @@ -24,7 +22,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) : sends the channel request. *) val client_of_flow : ?authenticator:Awa.Keys.authenticator -> user:string -> [ `Pubkey of Awa.Hostkey.priv | `Password of string ] -> - Awa.Ssh.channel_request -> FLOW.flow -> (flow, error) result Lwt.t + Awa.Ssh.channel_request -> F.flow -> (flow, error) result Lwt.t type t @@ -64,4 +62,4 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) : {b NOTE}: Even if the [ssh_channel_handler] is fulfilled, [spawn_server] continues to handle SSH channels. Only [stop] can really stop the internal SSH channels handler. *) -end with module FLOW = F +end From bc35b1b66617a5ea5775832a9c5b3b7f4ec8b4aa Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Thu, 8 Feb 2024 12:13:31 +0100 Subject: [PATCH 12/12] shutdown: if in closed/error state, call close on the underlying flow nevertheless --- mirage/awa_mirage.ml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mirage/awa_mirage.ml b/mirage/awa_mirage.ml index 7901f18..b5e44b5 100644 --- a/mirage/awa_mirage.ml +++ b/mirage/awa_mirage.ml @@ -178,7 +178,8 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = (match t.state with | `Closed -> F.close t.flow | _ -> Lwt.return_unit) - | `Error _ | `Closed -> Lwt.return_unit + | `Error _ | `Closed -> + F.close t.flow let writev t bufs = let open Lwt_result.Infix in