Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract the request ID without allocating extra memory. #735

Merged
merged 1 commit into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions lambda-runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use hyper::{
use lambda_runtime_api_client::Client;
use serde::{Deserialize, Serialize};
use std::{
convert::TryFrom,
env,
fmt::{self, Debug, Display},
future::Future,
Expand All @@ -41,6 +40,8 @@ mod types;
use requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest};
pub use types::{Context, FunctionResponse, IntoFunctionResponse, LambdaEvent, MetadataPrelude, StreamResponse};

use types::invoke_request_id;

/// Error type that lambdas may result in
pub type Error = lambda_runtime_api_client::Error;

Expand Down Expand Up @@ -121,6 +122,7 @@ where
trace!("New event arrived (run loop)");
let event = next_event_response?;
let (parts, body) = event.into_parts();
let request_id = invoke_request_id(&parts.headers)?;

#[cfg(debug_assertions)]
if parts.status == http::StatusCode::NO_CONTENT {
Expand All @@ -130,19 +132,8 @@ where
continue;
}

let ctx: Context = Context::try_from((self.config.clone(), parts.headers))?;
let request_id = &ctx.request_id.clone();

let request_span = match &ctx.xray_trace_id {
Some(trace_id) => {
env::set_var("_X_AMZN_TRACE_ID", trace_id);
tracing::info_span!("Lambda runtime invoke", requestId = request_id, xrayTraceId = trace_id)
}
None => {
env::remove_var("_X_AMZN_TRACE_ID");
tracing::info_span!("Lambda runtime invoke", requestId = request_id)
}
};
let ctx: Context = Context::new(request_id, self.config.clone(), &parts.headers)?;
let request_span = ctx.request_span();

// Group the handling in one future and instrument it with the span
async {
Expand Down
93 changes: 62 additions & 31 deletions lambda-runtime/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use crate::{Error, RefConfig};
use base64::prelude::*;
use bytes::Bytes;
use http::{HeaderMap, HeaderValue, StatusCode};
use http::{header::ToStrError, HeaderMap, HeaderValue, StatusCode};
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
convert::TryFrom,
env,
fmt::Debug,
time::{Duration, SystemTime},
};
use tokio_stream::Stream;
use tracing::Span;

#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
Expand Down Expand Up @@ -120,11 +121,10 @@ pub struct Context {
pub env_config: RefConfig,
}

impl TryFrom<(RefConfig, HeaderMap)> for Context {
type Error = Error;
fn try_from(data: (RefConfig, HeaderMap)) -> Result<Self, Self::Error> {
let env_config = data.0;
let headers = data.1;
impl Context {
/// Create a new [Context] struct based on the fuction configuration
/// and the incoming request data.
pub fn new(request_id: &str, env_config: RefConfig, headers: &HeaderMap) -> Result<Self, Error> {
let client_context: Option<ClientContext> = if let Some(value) = headers.get("lambda-runtime-client-context") {
serde_json::from_str(value.to_str()?)?
} else {
Expand All @@ -138,11 +138,7 @@ impl TryFrom<(RefConfig, HeaderMap)> for Context {
};

let ctx = Context {
request_id: headers
.get("lambda-runtime-aws-request-id")
.expect("missing lambda-runtime-aws-request-id header")
.to_str()?
.to_owned(),
request_id: request_id.to_owned(),
deadline: headers
.get("lambda-runtime-deadline-ms")
.expect("missing lambda-runtime-deadline-ms header")
Expand All @@ -165,13 +161,37 @@ impl TryFrom<(RefConfig, HeaderMap)> for Context {

Ok(ctx)
}
}

impl Context {
/// The execution deadline for the current invocation.
pub fn deadline(&self) -> SystemTime {
SystemTime::UNIX_EPOCH + Duration::from_millis(self.deadline)
}

/// Create a new [`tracing::Span`] for an incoming invocation.
pub(crate) fn request_span(&self) -> Span {
match &self.xray_trace_id {
Some(trace_id) => {
env::set_var("_X_AMZN_TRACE_ID", trace_id);
tracing::info_span!(
"Lambda runtime invoke",
requestId = &self.request_id,
xrayTraceId = trace_id
)
}
None => {
env::remove_var("_X_AMZN_TRACE_ID");
tracing::info_span!("Lambda runtime invoke", requestId = &self.request_id)
}
}
}
}

/// Extract the invocation request id from the incoming request.
pub(crate) fn invoke_request_id(headers: &HeaderMap) -> Result<&str, ToStrError> {
headers
.get("lambda-runtime-aws-request-id")
.expect("missing lambda-runtime-aws-request-id header")
.to_str()
}

/// Incoming Lambda request containing the event payload and context.
Expand Down Expand Up @@ -313,7 +333,7 @@ mod test {
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_ok());
}

Expand All @@ -324,7 +344,7 @@ mod test {
let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_ok());
}

Expand Down Expand Up @@ -355,7 +375,7 @@ mod test {
);

let config = Arc::new(Config::default());
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_ok());
let tried = tried.unwrap();
assert!(tried.client_context.is_some());
Expand All @@ -369,7 +389,7 @@ mod test {
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert("lambda-runtime-client-context", HeaderValue::from_static("{}"));
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_ok());
assert!(tried.unwrap().client_context.is_some());
}
Expand All @@ -390,7 +410,7 @@ mod test {
"lambda-runtime-cognito-identity",
HeaderValue::from_str(&cognito_identity_str).unwrap(),
);
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_ok());
let tried = tried.unwrap();
assert!(tried.identity.is_some());
Expand All @@ -412,7 +432,7 @@ mod test {
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_err());
}

Expand All @@ -427,7 +447,7 @@ mod test {
"lambda-runtime-client-context",
HeaderValue::from_static("BAD-Type,not JSON"),
);
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_err());
}

Expand All @@ -439,7 +459,7 @@ mod test {
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert("lambda-runtime-cognito-identity", HeaderValue::from_static("{}"));
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_err());
}

Expand All @@ -454,14 +474,13 @@ mod test {
"lambda-runtime-cognito-identity",
HeaderValue::from_static("BAD-Type,not JSON"),
);
let tried = Context::try_from((config, headers));
let tried = Context::new("id", config, &headers);
assert!(tried.is_err());
}

#[test]
#[should_panic]
#[allow(unused_must_use)]
fn context_with_missing_request_id_should_panic() {
fn context_with_missing_deadline_should_panic() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
Expand All @@ -471,22 +490,34 @@ mod test {
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
Context::try_from((config, headers));
let _ = Context::new("id", config, &headers);
}

#[test]
#[should_panic]
#[allow(unused_must_use)]
fn context_with_missing_deadline_should_panic() {
let config = Arc::new(Config::default());
fn invoke_request_id_should_not_panic() {
let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert(
"lambda-runtime-invoked-function-arn",
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));

let _ = invoke_request_id(&headers);
}

#[test]
#[should_panic]
fn invoke_request_id_should_panic() {
let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert(
"lambda-runtime-invoked-function-arn",
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
Context::try_from((config, headers));

let _ = invoke_request_id(&headers);
}
}