Skip to content

Commit

Permalink
Remove function config allocations per invocation. (#732)
Browse files Browse the repository at this point in the history
Every invocation clones the function config. This allocates memory in the heap for no reason.

This change removes those extra allocations by wrapping the config into an Arc and sharing that between invocations.

Signed-off-by: David Calavera <[email protected]>
  • Loading branch information
calavera authored Nov 20, 2023
1 parent d3e365c commit 53637e7
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 43 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[workspace]
resolver = "2"
members = [
"lambda-http",
"lambda-integration-tests",
Expand Down
2 changes: 1 addition & 1 deletion lambda-runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ hyper = { version = "0.14.20", features = [
"server",
] }
futures = "0.3"
serde = { version = "1", features = ["derive"] }
serde = { version = "1", features = ["derive", "rc"] }
serde_json = "^1"
bytes = "1.0"
http = "0.2"
Expand Down
28 changes: 18 additions & 10 deletions lambda-runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use std::{
future::Future,
marker::PhantomData,
panic,
sync::Arc,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_stream::{Stream, StreamExt};
Expand Down Expand Up @@ -58,6 +59,8 @@ pub struct Config {
pub log_group: String,
}

type RefConfig = Arc<Config>;

impl Config {
/// Attempts to read configuration from environment variables.
pub fn from_env() -> Result<Self, Error> {
Expand Down Expand Up @@ -86,7 +89,7 @@ where

struct Runtime<C: Service<http::Uri> = HttpConnector> {
client: Client<C>,
config: Config,
config: RefConfig,
}

impl<C> Runtime<C>
Expand Down Expand Up @@ -127,8 +130,7 @@ where
continue;
}

let ctx: Context = Context::try_from(parts.headers)?;
let ctx: Context = ctx.with_config(&self.config);
let ctx: Context = Context::try_from((self.config.clone(), parts.headers))?;
let request_id = &ctx.request_id.clone();

let request_span = match &ctx.xray_trace_id {
Expand Down Expand Up @@ -263,7 +265,10 @@ where
trace!("Loading config from env");
let config = Config::from_env()?;
let client = Client::builder().build().expect("Unable to create a runtime client");
let runtime = Runtime { client, config };
let runtime = Runtime {
client,
config: Arc::new(config),
};

let client = &runtime.client;
let incoming = incoming(client);
Expand Down Expand Up @@ -294,15 +299,15 @@ mod endpoint_tests {
},
simulated,
types::Diagnostic,
Error, Runtime,
Config, Error, Runtime,
};
use futures::future::BoxFuture;
use http::{uri::PathAndQuery, HeaderValue, Method, Request, Response, StatusCode, Uri};
use hyper::{server::conn::Http, service::service_fn, Body};
use lambda_runtime_api_client::Client;
use serde_json::json;
use simulated::DuplexStreamWrapper;
use std::{convert::TryFrom, env, marker::PhantomData};
use std::{convert::TryFrom, env, marker::PhantomData, sync::Arc};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
select,
Expand Down Expand Up @@ -531,9 +536,12 @@ mod endpoint_tests {
if env::var("AWS_LAMBDA_LOG_GROUP_NAME").is_err() {
env::set_var("AWS_LAMBDA_LOG_GROUP_NAME", "test_log");
}
let config = crate::Config::from_env().expect("Failed to read env vars");
let config = Config::from_env().expect("Failed to read env vars");

let runtime = Runtime { client, config };
let runtime = Runtime {
client,
config: Arc::new(config),
};
let client = &runtime.client;
let incoming = incoming(client).take(1);
runtime.run(incoming, f).await?;
Expand Down Expand Up @@ -568,13 +576,13 @@ mod endpoint_tests {

let f = crate::service_fn(func);

let config = crate::Config {
let config = Arc::new(Config {
function_name: "test_fn".to_string(),
memory: 128,
version: "1".to_string(),
log_stream: "test_stream".to_string(),
log_group: "test_log".to_string(),
};
});

let runtime = Runtime { client, config };
let client = &runtime.client;
Expand Down
81 changes: 49 additions & 32 deletions lambda-runtime/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{Config, Error};
use crate::{Error, RefConfig};
use base64::prelude::*;
use bytes::Bytes;
use http::{HeaderMap, HeaderValue, StatusCode};
Expand Down Expand Up @@ -97,7 +97,7 @@ pub struct CognitoIdentity {
/// are populated using the [Lambda environment variables](https://docs.aws.amazon.com/lambda/latest/dg/current-supported-versions.html)
/// and [the headers returned by the poll request to the Runtime APIs](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html#runtimes-api-next).
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq, Default, Serialize, Deserialize)]
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Context {
/// The AWS request ID generated by the Lambda service.
pub request_id: String,
Expand All @@ -117,12 +117,14 @@ pub struct Context {
/// Lambda function configuration from the local environment variables.
/// Includes information such as the function name, memory allocation,
/// version, and log streams.
pub env_config: Config,
pub env_config: RefConfig,
}

impl TryFrom<HeaderMap> for Context {
impl TryFrom<(RefConfig, HeaderMap)> for Context {
type Error = Error;
fn try_from(headers: HeaderMap) -> Result<Self, Self::Error> {
fn try_from(data: (RefConfig, HeaderMap)) -> Result<Self, Self::Error> {
let env_config = data.0;
let headers = data.1;
let client_context: Option<ClientContext> = if let Some(value) = headers.get("lambda-runtime-client-context") {
serde_json::from_str(value.to_str()?)?
} else {
Expand Down Expand Up @@ -158,13 +160,20 @@ impl TryFrom<HeaderMap> for Context {
.map(|v| String::from_utf8_lossy(v.as_bytes()).to_string()),
client_context,
identity,
..Default::default()
env_config,
};

Ok(ctx)
}
}

impl Context {
/// The execution deadline for the current invocation.
pub fn deadline(&self) -> SystemTime {
SystemTime::UNIX_EPOCH + Duration::from_millis(self.deadline)
}
}

/// Incoming Lambda request containing the event payload and context.
#[derive(Clone, Debug)]
pub struct LambdaEvent<T> {
Expand Down Expand Up @@ -273,6 +282,8 @@ where
#[cfg(test)]
mod test {
use super::*;
use crate::Config;
use std::sync::Arc;

#[test]
fn round_trip_lambda_error() {
Expand All @@ -292,6 +303,8 @@ mod test {

#[test]
fn context_with_expected_values_and_types_resolves() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
Expand All @@ -300,16 +313,18 @@ mod test {
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_ok());
}

#[test]
fn context_with_certain_missing_headers_still_resolves() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_ok());
}

Expand Down Expand Up @@ -338,7 +353,9 @@ mod test {
"lambda-runtime-client-context",
HeaderValue::from_str(&client_context_str).unwrap(),
);
let tried = Context::try_from(headers);

let config = Arc::new(Config::default());
let tried = Context::try_from((config, headers));
assert!(tried.is_ok());
let tried = tried.unwrap();
assert!(tried.client_context.is_some());
Expand All @@ -347,17 +364,20 @@ mod test {

#[test]
fn context_with_empty_client_context_resolves() {
let config = Arc::new(Config::default());
let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert("lambda-runtime-client-context", HeaderValue::from_static("{}"));
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_ok());
assert!(tried.unwrap().client_context.is_some());
}

#[test]
fn context_with_identity_resolves() {
let config = Arc::new(Config::default());

let cognito_identity = CognitoIdentity {
identity_id: String::new(),
identity_pool_id: String::new(),
Expand All @@ -370,7 +390,7 @@ mod test {
"lambda-runtime-cognito-identity",
HeaderValue::from_str(&cognito_identity_str).unwrap(),
);
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_ok());
let tried = tried.unwrap();
assert!(tried.identity.is_some());
Expand All @@ -379,6 +399,8 @@ mod test {

#[test]
fn context_with_bad_deadline_type_is_err() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert(
Expand All @@ -390,86 +412,81 @@ mod test {
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_err());
}

#[test]
fn context_with_bad_client_context_is_err() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert(
"lambda-runtime-client-context",
HeaderValue::from_static("BAD-Type,not JSON"),
);
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_err());
}

#[test]
fn context_with_empty_identity_is_err() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert("lambda-runtime-cognito-identity", HeaderValue::from_static("{}"));
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_err());
}

#[test]
fn context_with_bad_identity_is_err() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert(
"lambda-runtime-cognito-identity",
HeaderValue::from_static("BAD-Type,not JSON"),
);
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_err());
}

#[test]
#[should_panic]
#[allow(unused_must_use)]
fn context_with_missing_request_id_should_panic() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert(
"lambda-runtime-invoked-function-arn",
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
Context::try_from(headers);
Context::try_from((config, headers));
}

#[test]
#[should_panic]
#[allow(unused_must_use)]
fn context_with_missing_deadline_should_panic() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert(
"lambda-runtime-invoked-function-arn",
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
Context::try_from(headers);
}
}

impl Context {
/// Add environment details to the context by setting `env_config`.
pub fn with_config(self, config: &Config) -> Self {
Self {
env_config: config.clone(),
..self
}
}

/// The execution deadline for the current invocation.
pub fn deadline(&self) -> SystemTime {
SystemTime::UNIX_EPOCH + Duration::from_millis(self.deadline)
Context::try_from((config, headers));
}
}

0 comments on commit 53637e7

Please sign in to comment.