Skip to content

Commit

Permalink
Extract the request ID without allocating extra memory. (#735)
Browse files Browse the repository at this point in the history
Changes the way that the Context is initialized to receive the request ID as an argument. This way we also avoid allocating additional memory for it.

Signed-off-by: David Calavera <[email protected]>
  • Loading branch information
calavera authored Nov 28, 2023
1 parent 53637e7 commit b7df6fc
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 45 deletions.
19 changes: 5 additions & 14 deletions lambda-runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use hyper::{
use lambda_runtime_api_client::Client;
use serde::{Deserialize, Serialize};
use std::{
convert::TryFrom,
env,
fmt::{self, Debug, Display},
future::Future,
Expand All @@ -41,6 +40,8 @@ mod types;
use requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest};
pub use types::{Context, FunctionResponse, IntoFunctionResponse, LambdaEvent, MetadataPrelude, StreamResponse};

use types::invoke_request_id;

/// Error type that lambdas may result in
pub type Error = lambda_runtime_api_client::Error;

Expand Down Expand Up @@ -121,6 +122,7 @@ where
trace!("New event arrived (run loop)");
let event = next_event_response?;
let (parts, body) = event.into_parts();
let request_id = invoke_request_id(&parts.headers)?;

#[cfg(debug_assertions)]
if parts.status == http::StatusCode::NO_CONTENT {
Expand All @@ -130,19 +132,8 @@ where
continue;
}

let ctx: Context = Context::try_from((self.config.clone(), parts.headers))?;
let request_id = &ctx.request_id.clone();

let request_span = match &ctx.xray_trace_id {
Some(trace_id) => {
env::set_var("_X_AMZN_TRACE_ID", trace_id);
tracing::info_span!("Lambda runtime invoke", requestId = request_id, xrayTraceId = trace_id)
}
None => {
env::remove_var("_X_AMZN_TRACE_ID");
tracing::info_span!("Lambda runtime invoke", requestId = request_id)
}
};
let ctx: Context = Context::new(request_id, self.config.clone(), &parts.headers)?;
let request_span = ctx.request_span();

// Group the handling in one future and instrument it with the span
async {
Expand Down
93 changes: 62 additions & 31 deletions lambda-runtime/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use crate::{Error, RefConfig};
use base64::prelude::*;
use bytes::Bytes;
use http::{HeaderMap, HeaderValue, StatusCode};
use http::{header::ToStrError, HeaderMap, HeaderValue, StatusCode};
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
convert::TryFrom,
env,
fmt::Debug,
time::{Duration, SystemTime},
};
use tokio_stream::Stream;
use tracing::Span;

#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
Expand Down Expand Up @@ -120,11 +121,10 @@ pub struct Context {
pub env_config: RefConfig,
}

impl TryFrom<(RefConfig, HeaderMap)> for Context {
type Error = Error;
fn try_from(data: (RefConfig, HeaderMap)) -> Result<Self, Self::Error> {
let env_config = data.0;
let headers = data.1;
impl Context {
/// Create a new [Context] struct based on the fuction configuration
/// and the incoming request data.
pub fn new(request_id: &str, env_config: RefConfig, headers: &HeaderMap) -> Result<Self, Error> {
let client_context: Option<ClientContext> = if let Some(value) = headers.get("lambda-runtime-client-context") {
serde_json::from_str(value.to_str()?)?
} else {
Expand All @@ -138,11 +138,7 @@ impl TryFrom<(RefConfig, HeaderMap)> for Context {
};

let ctx = Context {
request_id: headers
.get("lambda-runtime-aws-request-id")
.expect("missing lambda-runtime-aws-request-id header")
.to_str()?
.to_owned(),
request_id: request_id.to_owned(),
deadline: headers
.get("lambda-runtime-deadline-ms")
.expect("missing lambda-runtime-deadline-ms header")
Expand All @@ -165,13 +161,37 @@ impl TryFrom<(RefConfig, HeaderMap)> for Context {

Ok(ctx)
}
}

impl Context {
/// The execution deadline for the current invocation.
pub fn deadline(&self) -> SystemTime {
SystemTime::UNIX_EPOCH + Duration::from_millis(self.deadline)
}

/// Create a new [`tracing::Span`] for an incoming invocation.
pub(crate) fn request_span(&self) -> Span {
match &self.xray_trace_id {
Some(trace_id) => {
env::set_var("_X_AMZN_TRACE_ID", trace_id);
tracing::info_span!(
"Lambda runtime invoke",
requestId = &self.request_id,
xrayTraceId = trace_id
)
}
None => {
env::remove_var("_X_AMZN_TRACE_ID");
tracing::info_span!("Lambda runtime invoke", requestId = &self.request_id)
}
}
}
}

/// Extract the invocation request id from the incoming request.
pub(crate) fn invoke_request_id(headers: &HeaderMap) -> Result<&str, ToStrError> {
headers
.get("lambda-runtime-aws-request-id")
.expect("missing lambda-runtime-aws-request-id header")
.to_str()
}

/// Incoming Lambda request containing the event payload and context.
Expand Down Expand Up @@ -313,7 +333,7 @@ mod test {
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_ok());
}

Expand All @@ -324,7 +344,7 @@ mod test {
let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_ok());
}

Expand Down Expand Up @@ -355,7 +375,7 @@ mod test {
);

let config = Arc::new(Config::default());
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_ok());
let tried = tried.unwrap();
assert!(tried.client_context.is_some());
Expand All @@ -369,7 +389,7 @@ mod test {
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert("lambda-runtime-client-context", HeaderValue::from_static("{}"));
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_ok());
assert!(tried.unwrap().client_context.is_some());
}
Expand All @@ -390,7 +410,7 @@ mod test {
"lambda-runtime-cognito-identity",
HeaderValue::from_str(&cognito_identity_str).unwrap(),
);
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_ok());
let tried = tried.unwrap();
assert!(tried.identity.is_some());
Expand All @@ -412,7 +432,7 @@ mod test {
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_err());
}

Expand All @@ -427,7 +447,7 @@ mod test {
"lambda-runtime-client-context",
HeaderValue::from_static("BAD-Type,not JSON"),
);
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_err());
}

Expand All @@ -439,7 +459,7 @@ mod test {
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert("lambda-runtime-cognito-identity", HeaderValue::from_static("{}"));
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_err());
}

Expand All @@ -454,14 +474,13 @@ mod test {
"lambda-runtime-cognito-identity",
HeaderValue::from_static("BAD-Type,not JSON"),
);
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_err());
}

#[test]
#[should_panic]
#[allow(unused_must_use)]
fn context_with_missing_request_id_should_panic() {
fn context_with_missing_deadline_should_panic() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
Expand All @@ -471,22 +490,34 @@ mod test {
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
Context::try_from((config, headers));
let _ = Context::new("id", config, &headers);
}

#[test]
#[should_panic]
#[allow(unused_must_use)]
fn context_with_missing_deadline_should_panic() {
let config = Arc::new(Config::default());
fn invoke_request_id_should_not_panic() {
let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert(
"lambda-runtime-invoked-function-arn",
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));

let _ = invoke_request_id(&headers);
}

#[test]
#[should_panic]
fn invoke_request_id_should_panic() {
let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert(
"lambda-runtime-invoked-function-arn",
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
Context::try_from((config, headers));

let _ = invoke_request_id(&headers);
}
}

0 comments on commit b7df6fc

Please sign in to comment.