Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt to mirage-flow 4 API #70

Merged
merged 12 commits into from
Feb 8, 2024
2 changes: 1 addition & 1 deletion awa-mirage.opam
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ depends: [
"lwt" {>= "5.3.0"}
"mirage-time" {>= "2.0.0"}
"duration" {>= "0.2.0"}
"mirage-flow" {>= "2.0.0"}
"mirage-flow" {>= "4.0.0"}
"mirage-clock" {>= "3.0.0"}
"logs"
]
Expand Down
20 changes: 20 additions & 0 deletions lib/client.ml
Original file line number Diff line number Diff line change
Expand Up @@ -523,3 +523,23 @@ let outgoing_data t ?(id = 0l) data =
let* c, frags = Channel.output_data c data in
let t' = { t with channels = Channel.update c t.channels } in
Ok (output_msgs t' frags)

let eof ?(id = 0l) t =
match
let* () = guard (established t) "not yet established" in
let* c = guard_some (Channel.lookup id t.channels) "no such channel" in
let msg = Ssh.Msg_channel_eof c.them.id in
Ok (output_msg t msg)
with
| Error _ -> t, None
| Ok (t, msg) -> t, Some msg

let close ?(id = 0l) t =
match
let* () = guard (established t) "not yet established" in
let* c = guard_some (Channel.lookup id t.channels) "no such channel" in
let msg = Ssh.Msg_channel_close c.them.id in
Ok (output_msg t msg)
with
| Error _ -> t, None
| Ok (t, msg) -> t, Some msg
4 changes: 4 additions & 0 deletions lib/client.mli
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,7 @@ val outgoing_request : t -> ?id:int32 -> ?want_reply:bool ->

val outgoing_data : t -> ?id:int32 -> Cstruct.t ->
(t * Cstruct.t list, string) result

val eof : ?id:int32 -> t -> t * Cstruct.t option

val close : ?id:int32 -> t -> t * Cstruct.t option
125 changes: 98 additions & 27 deletions mirage/awa_mirage.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,58 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
| #Mirage_flow.write_error as e -> Mirage_flow.pp_write_error ppf e
| #error as e -> pp_error ppf e

(* this is the flow of a ssh-client. be aware that we're only using a single
channel.

the state `Read_closed is set (a) when a TCP.read returned `Eof,
and (b) when the application did a shutdown `read (or `read_write).
the state `Write_closed is set (a) when a TCP.write returned `Closed,
and (b) when the application did a shutdown `write (or `read_write).

If we're in `Write_closed, and do a shutdown `read, we'll end up in
`Closed, and attempt to (a) send a SSH_MSG_CHANNEL_CLOSE and (b) TCP.close.
This may fail, since on the TCP layer, the connection may have already be
half-closed (or fully closed) in the write direction. We ignore this error
from writev below in close.
*)
type flow = {
flow : FLOW.flow ;
mutable state : [ `Active of Awa.Client.t | `Eof | `Error of error ]
mutable state : [
| `Active of Awa.Client.t
| `Read_closed of Awa.Client.t
| `Write_closed of Awa.Client.t
| `Closed
| `Error of error ]
}

let half_close state mode =
match state, mode with
| `Active ssh, `read -> `Read_closed ssh
| `Active ssh, `write -> `Write_closed ssh
| `Active _, `read_write -> `Closed
| `Read_closed ssh, `read -> `Read_closed ssh
| `Read_closed _, (`write | `read_write) -> `Closed
| `Write_closed ssh, `write -> `Write_closed ssh
| `Write_closed _, (`read | `read_write) -> `Closed
| (`Closed | `Error _) as e, (`read | `write | `read_write) -> e

let inject_state ssh = function
| `Active _ -> `Active ssh
| `Read_closed _ -> `Read_closed ssh
| `Write_closed _ -> `Write_closed ssh
| (`Closed | `Error _) as e -> e

let write_flow t buf =
FLOW.write t.flow buf >>= function
| Ok () -> Lwt.return (Ok ())
| Error `Closed ->
Log.warn (fun m -> m "error closed while writing");
t.state <- half_close t.state `write;
Lwt.return (Error (`Write `Closed))
| Error w ->
Log.warn (fun m -> m "error %a while writing" F.pp_write_error w);
t.state <- `Error (`Write w) ; Lwt.return (Error (`Write w))
t.state <- `Error (`Write w);
Lwt.return (Error (`Write w))

let writev_flow t bufs =
Lwt_list.fold_left_s (fun r d ->
Expand All @@ -46,25 +87,27 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =

let read_react t =
match t.state with
| `Eof | `Error _ -> Lwt.return (Error ())
| `Active _ ->
| `Read_closed _ | `Closed | `Error _ -> Lwt.return (Error ())
| `Active _ | `Write_closed _ ->
FLOW.read t.flow >>= function
| Error e ->
Log.warn (fun m -> m "error %a while reading" F.pp_error e);
t.state <- `Error (`Read e);
Lwt.return (Error ())
| Ok `Eof -> t.state <- `Eof ; Lwt.return (Error ())
| Ok `Eof ->
t.state <- half_close t.state `read;
Lwt.return (Error ())
| Ok (`Data data) ->
match t.state with
| `Active ssh ->
| `Active ssh | `Write_closed ssh ->
begin match Awa.Client.incoming ssh (now ()) data with
| Error msg ->
Log.warn (fun m -> m "error %s while processing data" msg);
t.state <- `Error (`Msg msg);
Lwt.return (Error ())
| Ok (ssh', out, events) ->
let state' = if List.mem `Disconnected events then `Eof else `Active ssh' in
t.state <- state';
t.state <-
inject_state ssh' (if List.mem `Disconnected events then half_close t.state `read else t.state);
writev_flow t out >>= fun _ ->
Lwt.return (Ok events)
end
Expand All @@ -74,15 +117,14 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
read_react t >>= function
| Ok es ->
begin match t.state, List.filter (function `Established _ -> true | _ -> false) es with
| `Eof, _ -> Lwt.return (Error (`Msg "disconnected"))
| (`Read_closed _ | `Closed), _ -> Lwt.return (Error (`Msg "disconnected"))
| `Error e, _ -> Lwt.return (Error e)
| `Active _, [ `Established id ] -> Lwt.return (Ok id)
| `Active _, _ -> drain_handshake t
| (`Active _ | `Write_closed _), [ `Established id ] -> Lwt.return (Ok id)
| (`Active _ | `Write_closed _), _ -> drain_handshake t
end
| Error () -> match t.state with
| `Error e -> Lwt.return (Error e)
| `Eof -> Lwt.return (Error (`Msg "disconnected"))
| `Active _ -> assert false
| `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt.return (Error (`Msg "disconnected"))

let rec read t =
read_react t >>= function
Expand All @@ -107,32 +149,56 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
end
| Error () -> match t.state with
| `Error e -> Lwt.return (Error e)
| `Eof -> Lwt.return (Ok `Eof)
| `Active _ -> assert false
| `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt.return (Ok `Eof)

let close t =
(* TODO ssh session teardown (send some protocol messages) *)
FLOW.close t.flow >|= fun () ->
t.state <- `Eof
(match t.state with
| `Active ssh | `Read_closed ssh | `Write_closed ssh ->
let ssh, msg = Awa.Client.close ssh in
t.state <- inject_state ssh t.state;
t.state <- `Closed;
(* as outlined above, this may fail since the TCP flow may already be (half-)closed *)
writev_flow t (Option.to_list msg) >|= ignore
| `Error _ | `Closed -> Lwt.return_unit) >>= fun () ->
FLOW.close t.flow
hannesm marked this conversation as resolved.
Show resolved Hide resolved

let shutdown t mode =
match t.state with
| `Active ssh | `Read_closed ssh | `Write_closed ssh ->
let ssh, msg =
match t.state, mode with
| (`Active ssh | `Read_closed ssh), `write -> Awa.Client.eof ssh
| _, `read_write -> Awa.Client.close ssh
| _ -> ssh, None
in
t.state <- inject_state ssh (half_close t.state mode);
(* as outlined above, this may fail since the TCP flow may already be (half-)closed *)
writev_flow t (Option.to_list msg) >>= fun _ ->
(* we don't [FLOW.shutdown _ mode] because we still need to read/write
channel_eof/channel_close unless both directions are closed *)
(match t.state with
| `Closed -> FLOW.close t.flow
| _ -> Lwt.return_unit)
| `Error _ | `Closed -> Lwt.return_unit

let writev t bufs =
let open Lwt_result.Infix in
match t.state with
| `Active ssh ->
| `Active ssh | `Read_closed ssh ->
Lwt_list.fold_left_s (fun r data ->
match r with
| Error e -> Lwt.return (Error e)
| Ok ssh ->
match Awa.Client.outgoing_data ssh data with
| Ok (ssh', datas) ->
t.state <- `Active ssh';
t.state <- inject_state ssh' t.state;
writev_flow t datas >|= fun () ->
ssh'
| Error msg ->
t.state <- `Error (`Msg msg) ;
Lwt.return (Error (`Msg msg)))
(Ok ssh) bufs >|= fun _ -> ()
| `Eof -> Lwt.return (Error `Closed)
| `Write_closed _ | `Closed -> Lwt.return (Error `Closed)
| `Error e -> Lwt.return (Error (e :> write_error))

let write t buf = writev t [buf]
Expand All @@ -146,12 +212,17 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
} in
writev_flow t msgs >>= fun () ->
drain_handshake t >>= fun id ->
(* TODO that's a bit hardcoded... *)
let ssh = match t.state with `Active t -> t | _ -> assert false in
(match Awa.Client.outgoing_request ssh ~id req with
| Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg))
| Ok (ssh', data) -> t.state <- `Active ssh' ; write_flow t data) >|= fun () ->
t
match t.state with
| `Active ssh ->
(match Awa.Client.outgoing_request ssh ~id req with
| Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg))
| Ok (ssh', data) -> t.state <- `Active ssh' ; write_flow t data) >|= fun () ->
t
| `Read_closed _ -> Lwt.return (Error (`Msg "read closed"))
| `Write_closed _ -> Lwt.return (Error (`Msg "write closed"))
| `Closed -> Lwt.return (Error (`Msg "closed"))
| `Error e -> Lwt.return (Error e)


(* copy from awa_lwt.ml and unix references removed in favor to FLOW *)
type nexus_msg =
Expand Down