From f4745e05ad1d1742b46a9a5f853d50bc3126e152 Mon Sep 17 00:00:00 2001 From: mbfreder Date: Tue, 3 Oct 2023 10:26:16 -0700 Subject: [PATCH] Implemented SnapStart Runtime Hooks --- Cargo.toml | 1 + examples/basic-runtime-hooks/.gitignore | 1 + examples/basic-runtime-hooks/Cargo.toml | 19 + examples/basic-runtime-hooks/README.md | 11 + examples/basic-runtime-hooks/src/main.rs | 94 +++++ examples/http-runtime-hooks/.gitignore | 1 + examples/http-runtime-hooks/Cargo.toml | 22 ++ examples/http-runtime-hooks/README.md | 11 + examples/http-runtime-hooks/src/main.rs | 115 ++++++ lambda-http/src/lib.rs | 175 ++++++++- lambda-http/src/streaming.rs | 93 ----- lambda-runtime/Cargo.toml | 4 + lambda-runtime/src/crac.rs | 86 +++++ lambda-runtime/src/lib.rs | 454 +++++++++++------------ lambda-runtime/src/requests.rs | 51 ++- lambda-runtime/src/simulated.rs | 100 ----- 16 files changed, 785 insertions(+), 453 deletions(-) create mode 100644 examples/basic-runtime-hooks/.gitignore create mode 100644 examples/basic-runtime-hooks/Cargo.toml create mode 100644 examples/basic-runtime-hooks/README.md create mode 100644 examples/basic-runtime-hooks/src/main.rs create mode 100644 examples/http-runtime-hooks/.gitignore create mode 100644 examples/http-runtime-hooks/Cargo.toml create mode 100644 examples/http-runtime-hooks/README.md create mode 100644 examples/http-runtime-hooks/src/main.rs delete mode 100644 lambda-http/src/streaming.rs create mode 100644 lambda-runtime/src/crac.rs delete mode 100644 lambda-runtime/src/simulated.rs diff --git a/Cargo.toml b/Cargo.toml index 48bcd5db..16f57a7b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,5 @@ [workspace] +resolver = "2" members = [ "lambda-http", "lambda-integration-tests", diff --git a/examples/basic-runtime-hooks/.gitignore b/examples/basic-runtime-hooks/.gitignore new file mode 100644 index 00000000..ea8c4bf7 --- /dev/null +++ b/examples/basic-runtime-hooks/.gitignore @@ -0,0 +1 @@ +/target diff --git a/examples/basic-runtime-hooks/Cargo.toml b/examples/basic-runtime-hooks/Cargo.toml new file mode 100644 index 00000000..742e3b8e --- /dev/null +++ b/examples/basic-runtime-hooks/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "basic-runtime-hooks" +version = "0.1.0" +edition = "2021" + + +# Use cargo-edit(https://github.com/killercup/cargo-edit#installation) +# to manage dependencies. +# Running `cargo add DEPENDENCY_NAME` will +# add the latest version of a dependency to the list, +# and it will keep the alphabetic ordering for you. + +[dependencies] +lambda_runtime = { path = "../../lambda-runtime" } +serde = "1.0.136" +tokio = { version = "1", features = ["macros"] } +tracing = { version = "0.1", features = ["log"] } +tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt"] } +uuid = { version = "1.4.1", features = ["v4"]} diff --git a/examples/basic-runtime-hooks/README.md b/examples/basic-runtime-hooks/README.md new file mode 100644 index 00000000..498f8a50 --- /dev/null +++ b/examples/basic-runtime-hooks/README.md @@ -0,0 +1,11 @@ +# AWS Lambda Function example + +## Build & Deploy + +1. Install [cargo-lambda](https://github.com/cargo-lambda/cargo-lambda#installation) +2. Build the function with `cargo lambda build --release` +3. 
Deploy the function to AWS Lambda with `cargo lambda deploy --iam-role YOUR_ROLE` + +## Build for ARM 64 + +Build the function with `cargo lambda build --release --arm64` diff --git a/examples/basic-runtime-hooks/src/main.rs b/examples/basic-runtime-hooks/src/main.rs new file mode 100644 index 00000000..ed1ffb60 --- /dev/null +++ b/examples/basic-runtime-hooks/src/main.rs @@ -0,0 +1,94 @@ +// This example demonstrates use of shared resources such as DB connections +// or local caches that can be initialized at the start of the runtime and +// reused by subsequent lambda handler calls. +// Run it with the following input: +// { "command": "do something" } + +use lambda_runtime::{crac::Resource, service_fn, Error, LambdaEvent, Runtime}; +use serde::{Deserialize, Serialize}; +use std::cell::RefCell; +use uuid::Uuid; + +/// This is also a made-up example. Requests come into the runtime as unicode +/// strings in json format, which can map to any structure that implements `serde::Deserialize` +/// The runtime pays no attention to the contents of the request payload. +#[derive(Deserialize)] +struct Request { + command: String, +} + +/// This is a made-up example of what a response structure may look like. +/// There is no restriction on what it can be. The runtime requires responses +/// to be serialized into json. The runtime pays no attention +/// to the contents of the response payload. +#[derive(Serialize)] +struct Response { + req_id: String, + msg: String, + secret: String, +} + +struct SharedClient { + name: &'static str, + secret: RefCell, +} + +impl SharedClient { + fn new(name: &'static str, secret: String) -> Self { + Self { + name, + secret: RefCell::new(secret), + } + } + + fn response(&self, req_id: String, command: String) -> Response { + Response { + req_id, + msg: format!("Command {} executed by {}.", command, self.name), + secret: self.secret.borrow().clone(), + } + } +} + +impl Resource for SharedClient { + fn before_checkpoint(&self) -> Result<(), Error> { + // clear the secret before checkpointing + *self.secret.borrow_mut() = String::new(); + tracing::info!("in before_checkpoint: secret={:?}", self.secret.borrow()); + Ok(()) + } + fn after_restore(&self) -> Result<(), Error> { + // regenerate the secret after restoring + let secret = Uuid::new_v4().to_string(); + *self.secret.borrow_mut() = secret; + tracing::info!("in after_restore: secret={:?}", self.secret.borrow()); + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<(), Error> { + // required to enable CloudWatch error logging by the runtime + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + // disable printing the name of the module in every log line. + .with_target(false) + // disabling time is handy because CloudWatch will add the ingestion time. 
+ .without_time() + .init(); + + let secret = Uuid::new_v4().to_string(); + let client = SharedClient::new("Shared Client 1 (perhaps a database)", secret); + let client_ref = &client; + tracing::info!("In main function: secret={:?}", client_ref.secret.borrow()); + + Runtime::new() + .register(client_ref) + .run(service_fn(move |event: LambdaEvent| async move { + tracing::info!("In handler function: secret={:?}", client_ref.secret.borrow()); + let command = event.payload.command; + Ok::(client_ref.response(event.context.request_id, command)) + })) + .await?; + Ok(()) +} diff --git a/examples/http-runtime-hooks/.gitignore b/examples/http-runtime-hooks/.gitignore new file mode 100644 index 00000000..ea8c4bf7 --- /dev/null +++ b/examples/http-runtime-hooks/.gitignore @@ -0,0 +1 @@ +/target diff --git a/examples/http-runtime-hooks/Cargo.toml b/examples/http-runtime-hooks/Cargo.toml new file mode 100644 index 00000000..07a96131 --- /dev/null +++ b/examples/http-runtime-hooks/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "http-runtime-hooks" +version = "0.1.0" +edition = "2021" + +# Starting in Rust 1.62 you can use `cargo add` to add dependencies +# to your project. +# +# If you're using an older Rust version, +# download cargo-edit(https://github.com/killercup/cargo-edit#installation) +# to install the `add` subcommand. +# +# Running `cargo add DEPENDENCY_NAME` will +# add the latest version of a dependency to the list, +# and it will keep the alphabetic ordering for you. + +[dependencies] +lambda_http = { path = "../../lambda-http" } +tokio = { version = "1", features = ["macros"] } +tracing = { version = "0.1", features = ["log"] } +tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt"] } +uuid = { version = "1.4.1", features = ["v4"]} diff --git a/examples/http-runtime-hooks/README.md b/examples/http-runtime-hooks/README.md new file mode 100644 index 00000000..498f8a50 --- /dev/null +++ b/examples/http-runtime-hooks/README.md @@ -0,0 +1,11 @@ +# AWS Lambda Function example + +## Build & Deploy + +1. Install [cargo-lambda](https://github.com/cargo-lambda/cargo-lambda#installation) +2. Build the function with `cargo lambda build --release` +3. 
Deploy the function to AWS Lambda with `cargo lambda deploy --iam-role YOUR_ROLE` + +## Build for ARM 64 + +Build the function with `cargo lambda build --release --arm64` diff --git a/examples/http-runtime-hooks/src/main.rs b/examples/http-runtime-hooks/src/main.rs new file mode 100644 index 00000000..e5f406f5 --- /dev/null +++ b/examples/http-runtime-hooks/src/main.rs @@ -0,0 +1,115 @@ +use lambda_http::{crac, service_fn, Body, Error, IntoResponse, Request, RequestExt, Response, Runtime}; +use std::sync::Arc; +use std::sync::RwLock; +use uuid::Uuid; + +struct SharedClient { + name: &'static str, + secret: Arc>, +} + +impl SharedClient { + fn new(name: &'static str, secret: String) -> Self { + Self { + name, + secret: Arc::new(RwLock::new(secret)), + } + } + + fn response(&self, req_id: String, first_name: &str) -> String { + format!("{}: Client ({}) invoked by {}.", req_id, self.name, first_name) + } +} + +impl crac::Resource for SharedClient { + fn before_checkpoint(&self) -> Result<(), Error> { + // clear the secret before checkpointing + { + let mut write_lock = self.secret.write().unwrap(); + *write_lock = String::new(); + } // release the write lock + + { + tracing::info!("in before_checkpoint: secret={:?}", self.secret.read().unwrap()); + } // release the read lock + + Ok(()) + } + fn after_restore(&self) -> Result<(), Error> { + // regenerate the secret after restoring + { + let mut write_lock = self.secret.write().unwrap(); + *write_lock = Uuid::new_v4().to_string(); + } // release the write lock + + { + tracing::info!("in after_restore: secret={:?}", self.secret.read().unwrap()); + } // release the read lock + + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<(), Error> { + // required to enable CloudWatch error logging by the runtime + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + // disable printing the name of the module in every log line. + .with_target(false) + // disabling time is handy because CloudWatch will add the ingestion time. + .without_time() + .init(); + + // Create the "client" and a reference to it, so that we can pass this into the handler closure below. + let secret = Uuid::new_v4().to_string(); + let shared_client = SharedClient::new("random_client_name_1", secret); + let shared_client_ref = &shared_client; + { + tracing::info!( + "In main function: secret={:?}", + shared_client_ref.secret.read().unwrap() + ); + } // release the read lock + + // Define a closure here that makes use of the shared client. + let handler_func_closure = move |event: Request| async move { + { + tracing::info!( + "In handler function: secret={:?}", + shared_client_ref.secret.read().unwrap() + ); + } // release the read lock + + Result::, Error>::Ok( + match event + .query_string_parameters_ref() + .and_then(|params| params.first("first_name")) + { + Some(first_name) => { + shared_client_ref + .response( + event + .lambda_context_ref() + .map(|ctx| ctx.request_id.clone()) + .unwrap_or_default(), + first_name, + ) + .into_response() + .await + } + None => Response::builder() + .status(400) + .body("Empty first name".into()) + .expect("failed to render response"), + }, + ) + }; + + // Pass the closure to the runtime here. 
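+ // `register` adds the shared client to the runtime's CRaC context so that its
+ // before_checkpoint/after_restore hooks run around the SnapStart snapshot,
+ // and `run` then starts the regular Lambda event loop with the handler closure.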
+ Runtime::new() + .register(shared_client_ref) + .run(service_fn(handler_func_closure)) + .await?; + Ok(()) +} diff --git a/lambda-http/src/lib.rs b/lambda-http/src/lib.rs index bc9e753d..a8402b66 100644 --- a/lambda-http/src/lib.rs +++ b/lambda-http/src/lib.rs @@ -64,9 +64,11 @@ #[macro_use] extern crate maplit; +use bytes::Bytes; pub use http::{self, Response}; -use lambda_runtime::LambdaEvent; -pub use lambda_runtime::{self, service_fn, tower, Context, Error, Service}; +pub use lambda_runtime::{ + self, crac, service_fn, tower, Context, Error, LambdaEvent, MetadataPrelude, Service, StreamResponse, +}; use request::RequestFuture; use response::ResponseFuture; @@ -89,15 +91,165 @@ use crate::{ pub use aws_lambda_events; pub use aws_lambda_events::encodings::Body; +use http::header::SET_COOKIE; use std::{ + fmt::Debug, + fmt::Display, future::Future, marker::PhantomData, pin::Pin, task::{Context as TaskContext, Poll}, }; +use tokio_stream::Stream; +use tower::{ServiceBuilder, ServiceExt}; + +/// Starts the Lambda Rust runtime and begins polling for events on the [Lambda +/// Runtime APIs](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html). +/// +/// This takes care of transforming the LambdaEvent into a [`Request`] and then +/// converting the result into a [`LambdaResponse`]. +pub async fn run<'a, R, S, E>(handler: S) -> Result<(), Error> +where + S: Service, + S::Future: Send + 'a, + R: IntoResponse, + E: std::fmt::Debug + std::fmt::Display, +{ + Runtime::new().register(&()).run(handler).await +} + +/// Starts the Lambda Rust runtime and stream response back [Configure Lambda +/// Streaming Response](https://docs.aws.amazon.com/lambda/latest/dg/configuration-response-streaming.html). +/// +/// This takes care of transforming the LambdaEvent into a [`Request`] and +/// accepts [`http::Response`] as response. +pub async fn run_with_streaming_response<'a, S, B, E>(handler: S) -> Result<(), Error> +where + S: Service, Error = E>, + S::Future: Send + 'a, + E: Debug + Display, + B: http_body::Body + Unpin + Send + 'static, + B::Data: Into + Send, + B::Error: Into + Send + Debug, +{ + Runtime::new().register(&()).run_with_streaming_response(handler).await +} + +/// The entry point for the lambda function using Builder pattern. +pub struct Runtime<'a, T: crac::Resource> { + inner: lambda_runtime::Runtime<'a, T>, +} + +impl<'a, T: crac::Resource> Default for Runtime<'a, T> { + fn default() -> Self { + Runtime::new() + } +} + +impl<'a, T: crac::Resource> Runtime<'a, T> { + /// Creates a new `Runtime` + pub fn new() -> Self { + Self { + inner: lambda_runtime::Runtime::new(), + } + } + + /// Registers a new crac::Resource with the `Runtime` + pub fn register(&mut self, resource: &'a T) -> &mut Self { + self.inner.register(resource); + self + } + + /// Runs the lambda function with the provided handler `Service` + pub async fn run(&self, handler: S) -> Result<(), Error> + where + S: Service, + S::Future: Send + 'a, + R: IntoResponse, + E: std::fmt::Debug + std::fmt::Display, + { + self.inner.run(Adapter::from(handler)).await + } + + /// Starts the Lambda Rust runtime and stream response back [Configure Lambda + /// Streaming Response](https://docs.aws.amazon.com/lambda/latest/dg/configuration-response-streaming.html). + /// + /// This takes care of transforming the LambdaEvent into a [`Request`] and + /// accepts [`http::Response`] as response. 
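+ ///
+ /// The response body can be any [`http_body::Body`] whose data chunks convert
+ /// into [`Bytes`] (for example a channel-backed streaming body); the status code,
+ /// headers, and cookies are sent ahead of the body as the metadata prelude.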
+ pub async fn run_with_streaming_response(&self, handler: S) -> Result<(), Error> + where + S: Service, Error = E>, + S::Future: Send + 'a, + E: Debug + Display, + B: http_body::Body + Unpin + Send + 'static, + B::Data: Into + Send, + B::Error: Into + Send + Debug, + { + let svc = ServiceBuilder::new() + .map_request(|req: LambdaEvent| { + let event: Request = req.payload.into(); + event.with_lambda_context(req.context) + }) + .service(handler) + .map_response(|res| { + let (parts, body) = res.into_parts(); + + let mut prelude_headers = parts.headers; + + let cookies = prelude_headers.get_all(SET_COOKIE); + let cookies = cookies + .iter() + .map(|c| String::from_utf8_lossy(c.as_bytes()).to_string()) + .collect::>(); + + prelude_headers.remove(SET_COOKIE); -mod streaming; -pub use streaming::run_with_streaming_response; + let metadata_prelude = MetadataPrelude { + headers: prelude_headers, + status_code: parts.status, + cookies, + }; + + StreamResponse { + metadata_prelude, + stream: BodyStream { body }, + } + }); + + self.inner.run(svc).await + } +} + +/// Stream wrapper for [`http_body::Body`] +pub struct BodyStream { + /// Wrapped [`http_body::Body`] + pub body: B, +} + +impl BodyStream +where + B: http_body::Body + Unpin + Send + 'static, + B::Data: Into + Send, + B::Error: Into + Send + Debug, +{ + fn project(self: Pin<&mut Self>) -> Pin<&mut B> { + unsafe { self.map_unchecked_mut(|s| &mut s.body) } + } +} + +impl Stream for BodyStream +where + B: http_body::Body + Unpin + Send + 'static, + B::Data: Into + Send, + B::Error: Into + Send + Debug, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let body = self.project(); + body.poll_data(cx) + } +} /// Type alias for `http::Request`s with a fixed [`Body`](enum.Body.html) type pub type Request = http::Request; @@ -181,21 +333,6 @@ where } } -/// Starts the Lambda Rust runtime and begins polling for events on the [Lambda -/// Runtime APIs](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html). -/// -/// This takes care of transforming the LambdaEvent into a [`Request`] and then -/// converting the result into a [`LambdaResponse`]. -pub async fn run<'a, R, S, E>(handler: S) -> Result<(), Error> -where - S: Service, - S::Future: Send + 'a, - R: IntoResponse, - E: std::fmt::Debug + std::fmt::Display, -{ - lambda_runtime::run(Adapter::from(handler)).await -} - #[cfg(test)] mod test_adapter { use std::task::{Context, Poll}; diff --git a/lambda-http/src/streaming.rs b/lambda-http/src/streaming.rs deleted file mode 100644 index a59cf700..00000000 --- a/lambda-http/src/streaming.rs +++ /dev/null @@ -1,93 +0,0 @@ -use crate::http::header::SET_COOKIE; -use crate::tower::ServiceBuilder; -use crate::Request; -use crate::{request::LambdaRequest, RequestExt}; -pub use aws_lambda_events::encodings::Body as LambdaEventBody; -use bytes::Bytes; -pub use http::{self, Response}; -use http_body::Body; -pub use lambda_runtime::{ - self, service_fn, tower, tower::ServiceExt, Error, FunctionResponse, LambdaEvent, MetadataPrelude, Service, - StreamResponse, -}; -use std::fmt::{Debug, Display}; -use std::pin::Pin; -use std::task::{Context, Poll}; -use tokio_stream::Stream; - -/// Starts the Lambda Rust runtime and stream response back [Configure Lambda -/// Streaming Response](https://docs.aws.amazon.com/lambda/latest/dg/configuration-response-streaming.html). -/// -/// This takes care of transforming the LambdaEvent into a [`Request`] and -/// accepts [`http::Response`] as response. 
-pub async fn run_with_streaming_response<'a, S, B, E>(handler: S) -> Result<(), Error> -where - S: Service, Error = E>, - S::Future: Send + 'a, - E: Debug + Display, - B: Body + Unpin + Send + 'static, - B::Data: Into + Send, - B::Error: Into + Send + Debug, -{ - let svc = ServiceBuilder::new() - .map_request(|req: LambdaEvent| { - let event: Request = req.payload.into(); - event.with_lambda_context(req.context) - }) - .service(handler) - .map_response(|res| { - let (parts, body) = res.into_parts(); - - let mut prelude_headers = parts.headers; - - let cookies = prelude_headers.get_all(SET_COOKIE); - let cookies = cookies - .iter() - .map(|c| String::from_utf8_lossy(c.as_bytes()).to_string()) - .collect::>(); - - prelude_headers.remove(SET_COOKIE); - - let metadata_prelude = MetadataPrelude { - headers: prelude_headers, - status_code: parts.status, - cookies, - }; - - StreamResponse { - metadata_prelude, - stream: BodyStream { body }, - } - }); - - lambda_runtime::run(svc).await -} - -pub struct BodyStream { - pub(crate) body: B, -} - -impl BodyStream -where - B: Body + Unpin + Send + 'static, - B::Data: Into + Send, - B::Error: Into + Send + Debug, -{ - fn project(self: Pin<&mut Self>) -> Pin<&mut B> { - unsafe { self.map_unchecked_mut(|s| &mut s.body) } - } -} - -impl Stream for BodyStream -where - B: Body + Unpin + Send + 'static, - B::Data: Into + Send, - B::Error: Into + Send + Debug, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let body = self.project(); - body.poll_data(cx) - } -} diff --git a/lambda-runtime/Cargo.toml b/lambda-runtime/Cargo.toml index 9202b1c1..faa922ae 100644 --- a/lambda-runtime/Cargo.toml +++ b/lambda-runtime/Cargo.toml @@ -45,3 +45,7 @@ serde_path_to_error = "0.1.11" http-serde = "1.1.3" base64 = "0.20.0" http-body = "0.4" +thiserror = "1.0" + +[dev-dependencies] +httpmock = "0.6" diff --git a/lambda-runtime/src/crac.rs b/lambda-runtime/src/crac.rs new file mode 100644 index 00000000..97883a40 --- /dev/null +++ b/lambda-runtime/src/crac.rs @@ -0,0 +1,86 @@ +use crate::Error; +use thiserror::Error; + +/// A trait for receiving checkpoint/restore notifications. +/// +/// The type that is interested in receiving a checkpoint/restore notification +/// implements this trait, and the instance created from that type is registered +/// inside the Runtime's list of resources, using the Runtime's register() method. +pub trait Resource { + /// Invoked by Runtime as a notification about checkpoint (that snapshot is about to be taken) + fn before_checkpoint(&self) -> Result<(), Error> { + Ok(()) + } + /// Invoked by Runtime as a notification about restore (snapshot was restored) + fn after_restore(&self) -> Result<(), Error> { + Ok(()) + } +} + +/// Errors that can occur during checkpoint/restore hooks +#[derive(Error, Debug)] +pub enum CracError { + /// Errors occurred during before_checkpoint() hook + #[error("before checkpoint hooks errors: {0}")] + BeforeCheckpointError(String), + /// Errors occurred during after_restore() hook + #[error("after restore hooks errors: {0}")] + AfterRestoreError(String), +} + +// implement a dummy Resource for unit type '()' +impl Resource for () {} + +/// A context for CRAC resources. +pub struct Context<'a, T: Resource> { + resources: Vec<&'a T>, +} + +impl<'a, T: Resource> Default for Context<'a, T> { + fn default() -> Self { + Context::new() + } +} + +impl<'a, T: Resource> Context<'a, T> { + /// Creates a new Context. 
+ pub fn new() -> Self { + Context { resources: Vec::new() } + } + + /// Registers a new resource. + pub fn register(&mut self, resource: &'a T) -> &mut Self { + self.resources.push(resource); + self + } + + /// Invokes before_checkpoint() on all registered resources in the reverse order of registration. + pub fn before_checkpoint(&self) -> Result<(), Error> { + let mut checkpint_errors: Vec = Vec::new(); + for resource in self.resources.iter().rev() { + let result = resource.before_checkpoint(); + if let Err(err) = result { + checkpint_errors.push(err.to_string()); + } + } + if !checkpint_errors.is_empty() { + return Err(Box::new(CracError::BeforeCheckpointError(checkpint_errors.join(", ")))); + } + Ok(()) + } + + /// Invokes after_restore() on all registered resources in the order of registration. + pub fn after_restore(&self) -> Result<(), Error> { + let mut restore_errors: Vec = Vec::new(); + for resource in &self.resources { + let result = resource.after_restore(); + if let Err(err) = result { + restore_errors.push(err.to_string()); + } + } + if !restore_errors.is_empty() { + return Err(Box::new(CracError::AfterRestoreError(restore_errors.join(", ")))); + } + Ok(()) + } +} diff --git a/lambda-runtime/src/lib.rs b/lambda-runtime/src/lib.rs index 18b1066e..7be529f0 100644 --- a/lambda-runtime/src/lib.rs +++ b/lambda-runtime/src/lib.rs @@ -9,11 +9,7 @@ //! and runs the Lambda runtime. use bytes::Bytes; use futures::FutureExt; -use hyper::{ - client::{connect::Connection, HttpConnector}, - http::Request, - Body, -}; +use hyper::{client::HttpConnector, http::Request, Body}; use lambda_runtime_api_client::Client; use serde::{Deserialize, Serialize}; use std::{ @@ -24,20 +20,21 @@ use std::{ marker::PhantomData, panic, }; -use tokio::io::{AsyncRead, AsyncWrite}; use tokio_stream::{Stream, StreamExt}; pub use tower::{self, service_fn, Service}; use tower::{util::ServiceFn, ServiceExt}; use tracing::{error, trace, Instrument}; +/// CRAC module contains the Resource trait for receiving checkpoint/restore notifications. +pub mod crac; mod deserializer; mod requests; -#[cfg(test)] -mod simulated; /// Types available to a Lambda function. mod types; -use requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest}; +use requests::{ + EventCompletionRequest, EventErrorRequest, InitErrorRequest, IntoRequest, NextEventRequest, RestoreNextRequest, +}; pub use types::{Context, FunctionResponse, IntoFunctionResponse, LambdaEvent, MetadataPrelude, StreamResponse}; /// Error type that lambdas may result in @@ -56,6 +53,8 @@ pub struct Config { pub log_stream: String, /// The name of the Amazon CloudWatch Logs group for the function. pub log_group: String, + /// The initialization type of the functin, which is 'on-demand', 'provisioned-concurrency', or 'snap-start'. + pub init_type: String, } impl Config { @@ -70,6 +69,7 @@ impl Config { version: env::var("AWS_LAMBDA_FUNCTION_VERSION").expect("Missing AWS_LAMBDA_FUNCTION_VERSION env var"), log_stream: env::var("AWS_LAMBDA_LOG_STREAM_NAME").unwrap_or_default(), log_group: env::var("AWS_LAMBDA_LOG_GROUP_NAME").unwrap_or_default(), + init_type: env::var("AWS_LAMBDA_INITIALIZATION_TYPE").unwrap_or_default(), }; Ok(conf) } @@ -84,19 +84,15 @@ where service_fn(move |req: LambdaEvent| f(req.payload, req.context)) } -struct Runtime = HttpConnector> { - client: Client, +/// The entry point for the lambda function using Builder pattern. 
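+///
+/// A minimal sketch of the intended usage, mirroring the `basic-runtime-hooks`
+/// example (`MyClient` here is just a placeholder type implementing `crac::Resource`):
+///
+/// ```no_run
+/// use lambda_runtime::{crac, service_fn, Error, LambdaEvent, Runtime};
+/// use serde_json::Value;
+///
+/// struct MyClient;
+/// impl crac::Resource for MyClient {}
+///
+/// #[tokio::main]
+/// async fn main() -> Result<(), Error> {
+///     let client = MyClient;
+///     Runtime::new()
+///         .register(&client)
+///         .run(service_fn(|event: LambdaEvent<Value>| async move {
+///             // echo the event payload back to the caller
+///             Ok::<Value, Error>(event.payload)
+///         }))
+///         .await?;
+///     Ok(())
+/// }
+/// ```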
+pub struct Runtime<'a, T: crac::Resource> { + client: Client, config: Config, + crac_context: crac::Context<'a, T>, } -impl Runtime -where - C: Service + Clone + Send + Sync + Unpin + 'static, - C::Future: Unpin + Send, - C::Error: Into>, - C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, -{ - async fn run( +impl<'a, T: crac::Resource> Runtime<'a, T> { + async fn execute( &self, incoming: impl Stream, Error>> + Send, mut handler: F, @@ -113,6 +109,9 @@ where E: Into + Send + Debug, { let client = &self.client; + if "snap-start" == self.config.init_type { + self.on_init_complete(client).await?; + } tokio::pin!(incoming); while let Some(next_event_response) = incoming.next().await { trace!("New event arrived (run loop)"); @@ -210,15 +209,73 @@ where } Ok(()) } + + async fn on_init_complete(&self, client: &Client) -> Result<(), Error> { + let res = self.crac_context.before_checkpoint(); + if let Err(err) = res { + error!("{:?}", err); + let req = InitErrorRequest::new("runtime.BeforeCheckpointError", &err.to_string()).into_req()?; + client.call(req).await?; + std::process::exit(64) + } + + let req = RestoreNextRequest + .into_req() + .expect("Unable to build restore next requests"); + // Blocking call to RAPID /runtime/restore/next API, will return after taking snapshot. + // This will also be the 'entrypoint' when resuming from snapshots. + client.call(req).await?; + + let res = self.crac_context.after_restore(); + if let Err(err) = res { + error!("{:?}", err); + let req = InitErrorRequest::new("runtime.AfterRestoreError", &err.to_string()).into_req()?; + client.call(req).await?; + std::process::exit(64) + } + + Ok(()) + } + + /// Creates a new Lambda Rust runtime. + pub fn new() -> Self { + trace!("Loading config from env"); + let config = Config::from_env().expect("Unable to parse config from environment variables"); + let client = Client::builder().build().expect("Unable to create a runtime client"); + Runtime { + client, + config, + crac_context: crac::Context::new(), + } + } + + /// Registers a crac::Resource with the runtime. + pub fn register(&mut self, resource: &'a T) -> &mut Self { + self.crac_context.register(resource); + self + } + + /// Runs the Lambda Rust runtime. + pub async fn run(&self, handler: F) -> Result<(), Error> + where + F: Service>, + F::Future: Future>, + F::Error: fmt::Debug + fmt::Display, + A: for<'de> Deserialize<'de>, + R: IntoFunctionResponse, + B: Serialize, + S: Stream> + Unpin + Send + 'static, + D: Into + Send, + E: Into + Send + Debug, + { + let incoming = incoming(&self.client); + self.execute(incoming, handler).await + } } -fn incoming(client: &Client) -> impl Stream, Error>> + Send + '_ -where - C: Service + Clone + Send + Sync + Unpin + 'static, - >::Future: Unpin + Send, - >::Error: Into>, - >::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, -{ +fn incoming( + client: &Client, +) -> impl Stream, Error>> + Send + '_ { async_stream::stream! { loop { trace!("Waiting for next event (incoming loop)"); @@ -229,6 +286,12 @@ where } } +impl<'a, T: crac::Resource> Default for Runtime<'a, T> { + fn default() -> Self { + Runtime::new() + } +} + /// Starts the Lambda Rust runtime and begins polling for events on the [Lambda /// Runtime APIs](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html). 
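+/// With no CRaC resources to register, this is equivalent to
+/// `Runtime::new().register(&()).run(handler)`.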
/// @@ -260,14 +323,7 @@ where D: Into + Send, E: Into + Send + Debug, { - trace!("Loading config from env"); - let config = Config::from_env()?; - let client = Client::builder().build().expect("Unable to create a runtime client"); - let runtime = Runtime { client, config }; - - let client = &runtime.client; - let incoming = incoming(client); - runtime.run(incoming, handler).await + Runtime::new().register(&()).run(handler).await } fn type_name_of_val(_: T) -> &'static str { @@ -288,223 +344,136 @@ where #[cfg(test)] mod endpoint_tests { use crate::{ - incoming, - requests::{ - EventCompletionRequest, EventErrorRequest, IntoRequest, IntoResponse, NextEventRequest, NextEventResponse, - }, - simulated, + crac, incoming, + requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest}, types::Diagnostic, Error, Runtime, }; use futures::future::BoxFuture; - use http::{uri::PathAndQuery, HeaderValue, Method, Request, Response, StatusCode, Uri}; - use hyper::{server::conn::Http, service::service_fn, Body}; + use http::{HeaderValue, StatusCode, Uri}; + use httpmock::prelude::*; + use hyper::{client::HttpConnector, Body}; use lambda_runtime_api_client::Client; use serde_json::json; - use simulated::DuplexStreamWrapper; use std::{convert::TryFrom, env, marker::PhantomData}; - use tokio::{ - io::{self, AsyncRead, AsyncWrite}, - select, - sync::{self, oneshot}, - }; use tokio_stream::StreamExt; - #[cfg(test)] - async fn next_event(req: &Request) -> Result, Error> { - let path = "/2018-06-01/runtime/invocation/next"; - assert_eq!(req.method(), Method::GET); - assert_eq!(req.uri().path_and_query().unwrap(), &PathAndQuery::from_static(path)); - let body = json!({"message": "hello"}); - - let rsp = NextEventResponse { - request_id: "8476a536-e9f4-11e8-9739-2dfe598c3fcd", - deadline: 1_542_409_706_888, - arn: "arn:aws:lambda:us-east-2:123456789012:function:custom-runtime", - trace_id: "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419", - body: serde_json::to_vec(&body)?, - }; - rsp.into_rsp() - } - - #[cfg(test)] - async fn complete_event(req: &Request, id: &str) -> Result, Error> { - assert_eq!(Method::POST, req.method()); - let rsp = Response::builder() - .status(StatusCode::ACCEPTED) - .body(Body::empty()) - .expect("Unable to construct response"); - - let expected = format!("/2018-06-01/runtime/invocation/{id}/response"); - assert_eq!(expected, req.uri().path()); - - Ok(rsp) - } - - #[cfg(test)] - async fn event_err(req: &Request, id: &str) -> Result, Error> { - let expected = format!("/2018-06-01/runtime/invocation/{id}/error"); - assert_eq!(expected, req.uri().path()); - - assert_eq!(req.method(), Method::POST); - let header = "lambda-runtime-function-error-type"; - let expected = "unhandled"; - assert_eq!(req.headers()[header], HeaderValue::try_from(expected)?); - - let rsp = Response::builder().status(StatusCode::ACCEPTED).body(Body::empty())?; - Ok(rsp) - } - - #[cfg(test)] - async fn handle_incoming(req: Request) -> Result, Error> { - let path: Vec<&str> = req - .uri() - .path_and_query() - .expect("PathAndQuery not found") - .as_str() - .split('/') - .collect::>(); - match path[1..] 
{ - ["2018-06-01", "runtime", "invocation", "next"] => next_event(&req).await, - ["2018-06-01", "runtime", "invocation", id, "response"] => complete_event(&req, id).await, - ["2018-06-01", "runtime", "invocation", id, "error"] => event_err(&req, id).await, - ["2018-06-01", "runtime", "init", "error"] => unimplemented!(), - _ => unimplemented!(), - } - } - - #[cfg(test)] - async fn handle(io: I, rx: oneshot::Receiver<()>) -> Result<(), hyper::Error> - where - I: AsyncRead + AsyncWrite + Unpin + 'static, - { - let conn = Http::new().serve_connection(io, service_fn(handle_incoming)); - select! { - _ = rx => { - Ok(()) - } - res = conn => { - match res { - Ok(()) => Ok(()), - Err(e) => { - Err(e) - } - } - } - } - } - #[tokio::test] - async fn test_next_event() -> Result<(), Error> { - let base = Uri::from_static("http://localhost:9001"); - let (client, server) = io::duplex(64); - - let (tx, rx) = sync::oneshot::channel(); - let server = tokio::spawn(async { - handle(server, rx).await.expect("Unable to handle request"); + async fn test_next_event() { + // Start mock runtime api server + let rapid = MockServer::start(); + let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9"; + let endpoint = rapid.mock(|when, then| { + when.method(GET).path("/2018-06-01/runtime/invocation/next"); + then.status(StatusCode::OK.as_u16()) + .header("Lambda-Runtime-Aws-Request-Id", request_id) + .header("Lambda-Runtime-Deadline-Ms", "1542409706888") + .body("ok"); }); - let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; - let client = Client::with(base, conn); - - let req = NextEventRequest.into_req()?; + // build the next event request and send to the mock endpoint + let base = Uri::try_from(format!("http://{}", rapid.address())).unwrap(); + let client = Client::with(base, HttpConnector::new()); + let req = NextEventRequest.into_req().unwrap(); let rsp = client.call(req).await.expect("Unable to send request"); + // Assert endpoint was called once + endpoint.assert(); + // and response has expected content assert_eq!(rsp.status(), StatusCode::OK); - let header = "lambda-runtime-deadline-ms"; - assert_eq!(rsp.headers()[header], &HeaderValue::try_from("1542409706888")?); - - // shutdown server... 
- tx.send(()).expect("Receiver has been dropped"); - match server.await { - Ok(_) => Ok(()), - Err(e) if e.is_panic() => Err::<(), Error>(e.into()), - Err(_) => unreachable!("This branch shouldn't be reachable"), - } + assert_eq!( + rsp.headers()["Lambda-Runtime-Aws-Request-Id"], + &HeaderValue::try_from(request_id).unwrap() + ); + assert_eq!( + rsp.headers()["Lambda-Runtime-Deadline-Ms"], + &HeaderValue::try_from("1542409706888").unwrap() + ); } #[tokio::test] - async fn test_ok_response() -> Result<(), Error> { - let (client, server) = io::duplex(64); - let (tx, rx) = sync::oneshot::channel(); - let base = Uri::from_static("http://localhost:9001"); - - let server = tokio::spawn(async { - handle(server, rx).await.expect("Unable to handle request"); + async fn test_ok_response() { + // Start mock runtime api server + let rapid = MockServer::start(); + let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9"; + let endpoint = rapid.mock(|when, then| { + when.method(POST) + .path(format!("/2018-06-01/runtime/invocation/{}/response", request_id)); + then.status(StatusCode::ACCEPTED.as_u16()); }); - let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; - let client = Client::with(base, conn); + // build the OK response and send to the mock endpoint + let base = Uri::try_from(format!("http://{}", rapid.address())).unwrap(); + let client = Client::with(base, HttpConnector::new()); let req = EventCompletionRequest { - request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9", + request_id, body: "done", _unused_b: PhantomData::<&str>, _unused_s: PhantomData::, }; - let req = req.into_req()?; + let req = req.into_req().unwrap(); - let rsp = client.call(req).await?; - assert_eq!(rsp.status(), StatusCode::ACCEPTED); + let rsp = client.call(req).await.unwrap(); - // shutdown server - tx.send(()).expect("Receiver has been dropped"); - match server.await { - Ok(_) => Ok(()), - Err(e) if e.is_panic() => Err::<(), Error>(e.into()), - Err(_) => unreachable!("This branch shouldn't be reachable"), - } + // Assert endpoint was called once + endpoint.assert(); + // and response has expected content + assert_eq!(rsp.status(), StatusCode::ACCEPTED); } #[tokio::test] - async fn test_error_response() -> Result<(), Error> { - let (client, server) = io::duplex(200); - let (tx, rx) = sync::oneshot::channel(); - let base = Uri::from_static("http://localhost:9001"); - - let server = tokio::spawn(async { - handle(server, rx).await.expect("Unable to handle request"); + async fn test_error_response() { + // Start mock runtime api server + let rapid = MockServer::start(); + let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9"; + let endpoint = rapid.mock(|when, then| { + when.method(POST) + .path(format!("/2018-06-01/runtime/invocation/{}/error", request_id)); + then.status(StatusCode::ACCEPTED.as_u16()); }); - let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; - let client = Client::with(base, conn); + // build the ERROR response and send to the mock endpoint + let base = Uri::try_from(format!("http://{}", rapid.address())).unwrap(); + let client = Client::with(base, HttpConnector::new()); let req = EventErrorRequest { - request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9", + request_id, diagnostic: Diagnostic { error_type: "InvalidEventDataError", error_message: "Error parsing event data", }, }; - let req = req.into_req()?; - let rsp = client.call(req).await?; - assert_eq!(rsp.status(), StatusCode::ACCEPTED); + let req = req.into_req().unwrap(); + let rsp = 
client.call(req).await.unwrap(); - // shutdown server - tx.send(()).expect("Receiver has been dropped"); - match server.await { - Ok(_) => Ok(()), - Err(e) if e.is_panic() => Err::<(), Error>(e.into()), - Err(_) => unreachable!("This branch shouldn't be reachable"), - } + // Assert endpoint was called once + endpoint.assert(); + // and response has expected content + assert_eq!(rsp.status(), StatusCode::ACCEPTED); } #[tokio::test] - async fn successful_end_to_end_run() -> Result<(), Error> { - let (client, server) = io::duplex(64); - let (tx, rx) = sync::oneshot::channel(); - let base = Uri::from_static("http://localhost:9001"); - - let server = tokio::spawn(async { - handle(server, rx).await.expect("Unable to handle request"); + async fn successful_end_to_end_run() { + // Start mock runtime api server + let rapid = MockServer::start(); + let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9"; + let next_endpoint = rapid.mock(|when, then| { + when.method(GET).path("/2018-06-01/runtime/invocation/next"); + then.status(StatusCode::OK.as_u16()) + .header("Lambda-Runtime-Aws-Request-Id", request_id) + .header("Lambda-Runtime-Deadline-Ms", "1542409706888") + .body(json!({"command": "hello"}).to_string()); + }); + let response_endpoint = rapid.mock(|when, then| { + when.method(POST) + .path(format!("/2018-06-01/runtime/invocation/{}/response", request_id)); + then.status(StatusCode::ACCEPTED.as_u16()); }); - let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; - let client = Client::builder() - .with_endpoint(base) - .with_connector(conn) - .build() - .expect("Unable to build client"); + // build the client to the mock endpoint + let base = Uri::try_from(format!("http://{}", rapid.address())).unwrap(); + let client = Client::with(base, HttpConnector::new()); async fn func(event: crate::LambdaEvent) -> Result { let (event, _) = event.into_parts(); @@ -533,38 +502,43 @@ mod endpoint_tests { } let config = crate::Config::from_env().expect("Failed to read env vars"); - let runtime = Runtime { client, config }; + let runtime = Runtime { + client, + config, + crac_context: crac::Context::<()>::new(), + }; let client = &runtime.client; let incoming = incoming(client).take(1); - runtime.run(incoming, f).await?; - - // shutdown server - tx.send(()).expect("Receiver has been dropped"); - match server.await { - Ok(_) => Ok(()), - Err(e) if e.is_panic() => Err::<(), Error>(e.into()), - Err(_) => unreachable!("This branch shouldn't be reachable"), - } + runtime.execute(incoming, f).await.unwrap(); + + // Assert endpoints were called + next_endpoint.assert(); + response_endpoint.assert(); } - async fn run_panicking_handler(func: F) -> Result<(), Error> + async fn run_panicking_handler(func: F) where F: FnMut(crate::LambdaEvent) -> BoxFuture<'static, Result>, { - let (client, server) = io::duplex(64); - let (_tx, rx) = oneshot::channel(); - let base = Uri::from_static("http://localhost:9001"); - - let server = tokio::spawn(async { - handle(server, rx).await.expect("Unable to handle request"); + // Start mock runtime api server + let rapid = MockServer::start(); + let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9"; + let next_endpoint = rapid.mock(|when, then| { + when.method(GET).path("/2018-06-01/runtime/invocation/next"); + then.status(StatusCode::OK.as_u16()) + .header("Lambda-Runtime-Aws-Request-Id", request_id) + .header("Lambda-Runtime-Deadline-Ms", "1542409706888") + .body(json!({"command": "hello"}).to_string()); + }); + let error_endpoint = rapid.mock(|when, 
then| { + when.method(POST) + .path(format!("/2018-06-01/runtime/invocation/{}/error", request_id)); + then.status(StatusCode::ACCEPTED.as_u16()); }); - let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; - let client = Client::builder() - .with_endpoint(base) - .with_connector(conn) - .build() - .expect("Unable to build client"); + // build the client to the mock endpoint + let base = Uri::try_from(format!("http://{}", rapid.address())).unwrap(); + let client = Client::with(base, HttpConnector::new()); let f = crate::service_fn(func); @@ -574,30 +548,36 @@ mod endpoint_tests { version: "1".to_string(), log_stream: "test_stream".to_string(), log_group: "test_log".to_string(), + init_type: "on-demand".to_string(), }; - let runtime = Runtime { client, config }; + let runtime = Runtime { + client, + config, + crac_context: crac::Context::<()>::new(), + }; let client = &runtime.client; let incoming = incoming(client).take(1); - runtime.run(incoming, f).await?; + runtime.execute(incoming, f).await.unwrap(); - match server.await { - Ok(_) => Ok(()), - Err(e) if e.is_panic() => Err::<(), Error>(e.into()), - Err(_) => unreachable!("This branch shouldn't be reachable"), - } + // Assert endpoints were called + next_endpoint.assert(); + error_endpoint.assert(); } #[tokio::test] - async fn panic_in_async_run() -> Result<(), Error> { + async fn panic_in_async_run() { run_panicking_handler(|_| Box::pin(async { panic!("This is intentionally here") })).await } #[tokio::test] - async fn panic_outside_async_run() -> Result<(), Error> { + async fn panic_outside_async_run() { run_panicking_handler(|_| { panic!("This is intentionally here"); }) .await } + + #[tokio::test] + async fn test_snapstart_runtime_hooks() {} } diff --git a/lambda-runtime/src/requests.rs b/lambda-runtime/src/requests.rs index 8e72fc2d..dfc77128 100644 --- a/lambda-runtime/src/requests.rs +++ b/lambda-runtime/src/requests.rs @@ -227,25 +227,45 @@ fn test_event_error_request() { } // /runtime/init/error -struct InitErrorRequest; +pub(crate) struct InitErrorRequest<'a> { + pub(crate) diagnostic: Diagnostic<'a>, +} -impl IntoRequest for InitErrorRequest { +impl<'a> InitErrorRequest<'a> { + pub(crate) fn new(error_type: &'a str, error_message: &'a str) -> InitErrorRequest<'a> { + InitErrorRequest { + diagnostic: Diagnostic { + error_type, + error_message, + }, + } + } +} + +impl<'a> IntoRequest for InitErrorRequest<'a> { fn into_req(self) -> Result, Error> { let uri = "/2018-06-01/runtime/init/error".to_string(); let uri = Uri::from_str(&uri)?; + let body = serde_json::to_vec(&self.diagnostic)?; + let body = Body::from(body); let req = build_request() .method(Method::POST) .uri(uri) .header("lambda-runtime-function-error-type", "unhandled") - .body(Body::empty())?; + .body(body)?; Ok(req) } } #[test] fn test_init_error_request() { - let req = InitErrorRequest; + let req = InitErrorRequest { + diagnostic: Diagnostic { + error_type: "runtime.InitError", + error_message: "SnapShot Runtime Hook Error", + }, + }; let req = req.into_req().unwrap(); let expected = Uri::from_static("/2018-06-01/runtime/init/error"); assert_eq!(req.method(), Method::POST); @@ -255,3 +275,26 @@ fn test_init_error_request() { None => false, }); } + +pub(crate) struct RestoreNextRequest; + +impl IntoRequest for RestoreNextRequest { + fn into_req(self) -> Result, Error> { + let req = build_request() + .method(Method::GET) + .uri(Uri::from_static("/2018-06-01/runtime/restore/next")) + .body(Body::empty())?; + Ok(req) + } +} +#[test] +fn 
test_restore_next_event_request() { + let req = RestoreNextRequest; + let req = req.into_req().unwrap(); + assert_eq!(req.method(), Method::GET); + assert_eq!(req.uri(), &Uri::from_static("/2018-06-01/runtime/restore/next")); + assert!(match req.headers().get("User-Agent") { + Some(header) => header.to_str().unwrap().starts_with("aws-lambda-rust/"), + None => false, + }); +} diff --git a/lambda-runtime/src/simulated.rs b/lambda-runtime/src/simulated.rs deleted file mode 100644 index f6a06bca..00000000 --- a/lambda-runtime/src/simulated.rs +++ /dev/null @@ -1,100 +0,0 @@ -use http::Uri; -use hyper::client::connect::Connection; -use std::{ - collections::HashMap, - future::Future, - io::Result as IoResult, - pin::Pin, - sync::{Arc, Mutex}, - task::{Context, Poll}, -}; -use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf}; - -use crate::Error; - -#[derive(Clone)] -pub struct Connector { - inner: Arc>>, -} - -pub struct DuplexStreamWrapper(DuplexStream); - -impl DuplexStreamWrapper { - pub(crate) fn new(stream: DuplexStream) -> DuplexStreamWrapper { - DuplexStreamWrapper(stream) - } -} - -impl Connector { - pub fn new() -> Self { - #[allow(clippy::mutable_key_type)] - let map = HashMap::new(); - Connector { - inner: Arc::new(Mutex::new(map)), - } - } - - pub fn insert(&self, uri: Uri, stream: DuplexStreamWrapper) -> Result<(), Error> { - match self.inner.lock() { - Ok(mut map) => { - map.insert(uri, stream); - Ok(()) - } - Err(_) => Err("mutex was poisoned".into()), - } - } - - pub fn with(uri: Uri, stream: DuplexStreamWrapper) -> Result { - let connector = Connector::new(); - match connector.insert(uri, stream) { - Ok(_) => Ok(connector), - Err(e) => Err(e), - } - } -} - -impl hyper::service::Service for Connector { - type Response = DuplexStreamWrapper; - type Error = crate::Error; - #[allow(clippy::type_complexity)] - type Future = Pin> + Send>>; - - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, uri: Uri) -> Self::Future { - let res = match self.inner.lock() { - Ok(mut map) if map.contains_key(&uri) => Ok(map.remove(&uri).unwrap()), - Ok(_) => Err(format!("Uri {uri} is not in map").into()), - Err(_) => Err("mutex was poisoned".into()), - }; - Box::pin(async move { res }) - } -} - -impl Connection for DuplexStreamWrapper { - fn connected(&self) -> hyper::client::connect::Connected { - hyper::client::connect::Connected::new() - } -} - -impl AsyncRead for DuplexStreamWrapper { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { - Pin::new(&mut self.0).poll_read(cx, buf) - } -} - -impl AsyncWrite for DuplexStreamWrapper { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - Pin::new(&mut self.0).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_shutdown(cx) - } -}