Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt to mirage-flow 4 API #70

Merged
merged 12 commits into from
Feb 8, 2024
2 changes: 1 addition & 1 deletion awa-mirage.opam
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ depends: [
"lwt" {>= "5.3.0"}
"mirage-time" {>= "2.0.0"}
"duration" {>= "0.2.0"}
"mirage-flow" {>= "2.0.0"}
"mirage-flow" {>= "4.0.0"}
"mirage-clock" {>= "3.0.0"}
"logs"
]
Expand Down
20 changes: 20 additions & 0 deletions lib/client.ml
Original file line number Diff line number Diff line change
Expand Up @@ -523,3 +523,23 @@ let outgoing_data t ?(id = 0l) data =
let* c, frags = Channel.output_data c data in
let t' = { t with channels = Channel.update c t.channels } in
Ok (output_msgs t' frags)

let eof ?(id = 0l) t =
match
let* () = guard (established t) "not yet established" in
let* c = guard_some (Channel.lookup id t.channels) "no such channel" in
let msg = Ssh.Msg_channel_eof c.them.id in
Ok (output_msg t msg)
with
| Error _ -> t, None
| Ok (t, msg) -> t, Some msg

let close ?(id = 0l) t =
match
let* () = guard (established t) "not yet established" in
let* c = guard_some (Channel.lookup id t.channels) "no such channel" in
let msg = Ssh.Msg_channel_close c.them.id in
Ok (output_msg t msg)
with
| Error _ -> t, None
| Ok (t, msg) -> t, Some msg
4 changes: 4 additions & 0 deletions lib/client.mli
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,7 @@ val outgoing_request : t -> ?id:int32 -> ?want_reply:bool ->

val outgoing_data : t -> ?id:int32 -> Cstruct.t ->
(t * Cstruct.t list, string) result

val eof : ?id:int32 -> t -> t * Cstruct.t option

val close : ?id:int32 -> t -> t * Cstruct.t option
111 changes: 84 additions & 27 deletions mirage/awa_mirage.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,42 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =

type flow = {
flow : FLOW.flow ;
mutable state : [ `Active of Awa.Client.t | `Eof | `Error of error ]
mutable state : [
| `Active of Awa.Client.t
| `Read_closed of Awa.Client.t
| `Write_closed of Awa.Client.t
| `Closed
| `Error of error ]
}

let half_close state mode =
match state, mode with
| `Active ssh, `read -> `Read_closed ssh
| `Active ssh, `write -> `Write_closed ssh
| `Active _, `read_write -> `Closed
| `Read_closed ssh, `read -> `Read_closed ssh
| `Read_closed _, (`write | `read_write) -> `Closed
| `Write_closed ssh, `write -> `Write_closed ssh
| `Write_closed _, (`read | `read_write) -> `Closed
| (`Closed | `Error _) as e, (`read | `write | `read_write) -> e

let inject_state ssh = function
| `Active _ -> `Active ssh
| `Read_closed _ -> `Read_closed ssh
| `Write_closed _ -> `Write_closed ssh
| (`Closed | `Error _) as e -> e

let write_flow t buf =
FLOW.write t.flow buf >>= function
| Ok () -> Lwt.return (Ok ())
| Error `Closed ->
Log.warn (fun m -> m "error closed while writing");
t.state <- half_close t.state `write;
Lwt.return (Error (`Write `Closed))
| Error w ->
Log.warn (fun m -> m "error %a while writing" F.pp_write_error w);
t.state <- `Error (`Write w) ; Lwt.return (Error (`Write w))
t.state <- `Error (`Write w);
Lwt.return (Error (`Write w))

let writev_flow t bufs =
Lwt_list.fold_left_s (fun r d ->
Expand All @@ -46,25 +73,27 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =

let read_react t =
match t.state with
| `Eof | `Error _ -> Lwt.return (Error ())
| `Active _ ->
| `Read_closed _ | `Closed | `Error _ -> Lwt.return (Error ())
| `Active _ | `Write_closed _ ->
FLOW.read t.flow >>= function
| Error e ->
Log.warn (fun m -> m "error %a while reading" F.pp_error e);
t.state <- `Error (`Read e);
Lwt.return (Error ())
| Ok `Eof -> t.state <- `Eof ; Lwt.return (Error ())
| Ok `Eof ->
t.state <- half_close t.state `read;
Lwt.return (Error ())
| Ok (`Data data) ->
match t.state with
| `Active ssh ->
| `Active ssh | `Write_closed ssh ->
begin match Awa.Client.incoming ssh (now ()) data with
| Error msg ->
Log.warn (fun m -> m "error %s while processing data" msg);
t.state <- `Error (`Msg msg);
Lwt.return (Error ())
| Ok (ssh', out, events) ->
let state' = if List.mem `Disconnected events then `Eof else `Active ssh' in
t.state <- state';
t.state <-
inject_state ssh' (if List.mem `Disconnected events then half_close t.state `read else t.state);
writev_flow t out >>= fun _ ->
Lwt.return (Ok events)
end
Expand All @@ -74,15 +103,14 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
read_react t >>= function
| Ok es ->
begin match t.state, List.filter (function `Established _ -> true | _ -> false) es with
| `Eof, _ -> Lwt.return (Error (`Msg "disconnected"))
| (`Read_closed _ | `Closed), _ -> Lwt.return (Error (`Msg "disconnected"))
| `Error e, _ -> Lwt.return (Error e)
| `Active _, [ `Established id ] -> Lwt.return (Ok id)
| `Active _, _ -> drain_handshake t
| (`Active _ | `Write_closed _), [ `Established id ] -> Lwt.return (Ok id)
| (`Active _ | `Write_closed _), _ -> drain_handshake t
end
| Error () -> match t.state with
| `Error e -> Lwt.return (Error e)
| `Eof -> Lwt.return (Error (`Msg "disconnected"))
| `Active _ -> assert false
| `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt.return (Error (`Msg "disconnected"))

let rec read t =
read_react t >>= function
Expand All @@ -107,32 +135,56 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
end
| Error () -> match t.state with
| `Error e -> Lwt.return (Error e)
| `Eof -> Lwt.return (Ok `Eof)
| `Active _ -> assert false
| `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt.return (Ok `Eof)

let close t =
(* TODO ssh session teardown (send some protocol messages) *)
FLOW.close t.flow >|= fun () ->
t.state <- `Eof
(match t.state with
| `Active ssh | `Read_closed ssh | `Write_closed ssh ->
let ssh, msg = Awa.Client.close ssh in
t.state <- inject_state ssh t.state;
writev_flow t (Option.to_list msg) >|= fun _ ->
t.state <- `Closed
hannesm marked this conversation as resolved.
Show resolved Hide resolved
| `Error _ | `Closed -> Lwt.return_unit) >>= fun () ->
FLOW.close t.flow
hannesm marked this conversation as resolved.
Show resolved Hide resolved

let shutdown t mode =
match t.state with
| `Active ssh | `Read_closed ssh | `Write_closed ssh ->
let ssh, msg =
match t.state, mode with
| (`Active ssh | `Read_closed ssh), `write -> Awa.Client.eof ssh
| _, `read_write -> Awa.Client.close ssh
| _ -> ssh, None
in
t.state <- inject_state ssh (half_close t.state mode);
writev_flow t (Option.to_list msg) >>= fun _ ->
(* we don't [FLOW.shutdown _ mode] because we still need to read/write
channel_eof/channel_close unless both directions are closed *)
(match t.state with
| `Closed ->
(* also close on error?! *)
hannesm marked this conversation as resolved.
Show resolved Hide resolved
FLOW.close t.flow
| _ -> Lwt.return_unit)
| `Error _ | `Closed -> Lwt.return_unit

let writev t bufs =
let open Lwt_result.Infix in
match t.state with
| `Active ssh ->
| `Active ssh | `Read_closed ssh ->
Lwt_list.fold_left_s (fun r data ->
match r with
| Error e -> Lwt.return (Error e)
| Ok ssh ->
match Awa.Client.outgoing_data ssh data with
| Ok (ssh', datas) ->
t.state <- `Active ssh';
t.state <- inject_state ssh' t.state;
writev_flow t datas >|= fun () ->
ssh'
| Error msg ->
t.state <- `Error (`Msg msg) ;
Lwt.return (Error (`Msg msg)))
(Ok ssh) bufs >|= fun _ -> ()
| `Eof -> Lwt.return (Error `Closed)
| `Write_closed _ | `Closed -> Lwt.return (Error `Closed)
| `Error e -> Lwt.return (Error (e :> write_error))

let write t buf = writev t [buf]
Expand All @@ -146,12 +198,17 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
} in
writev_flow t msgs >>= fun () ->
drain_handshake t >>= fun id ->
(* TODO that's a bit hardcoded... *)
let ssh = match t.state with `Active t -> t | _ -> assert false in
(match Awa.Client.outgoing_request ssh ~id req with
| Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg))
| Ok (ssh', data) -> t.state <- `Active ssh' ; write_flow t data) >|= fun () ->
t
match t.state with
| `Active ssh ->
(match Awa.Client.outgoing_request ssh ~id req with
| Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg))
| Ok (ssh', data) -> t.state <- `Active ssh' ; write_flow t data) >|= fun () ->
t
| `Read_closed _ -> Lwt.return (Error (`Msg "read closed"))
| `Write_closed _ -> Lwt.return (Error (`Msg "write closed"))
| `Closed -> Lwt.return (Error (`Msg "closed"))
| `Error e -> Lwt.return (Error e)


(* copy from awa_lwt.ml and unix references removed in favor to FLOW *)
type nexus_msg =
Expand Down