Skip to content

Commit

Permalink
Add support for Lambda-Extesion-Accept-Feature header (#887)
Browse files Browse the repository at this point in the history
- Use this header to read the account id that the extesion is installed in.
- Keep the account id as optional just in case we make this feature optional in the future.
- Users should not rely on it being always present.
- Extract all the information that the register call provides.
  • Loading branch information
calavera authored Jun 2, 2024
1 parent 0fcba16 commit 1cf868c
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 8 deletions.
60 changes: 52 additions & 8 deletions lambda-extension/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use hyper::service::service_fn;

use hyper_util::rt::tokio::TokioIo;
use lambda_runtime_api_client::Client;
use serde::Deserialize;
use std::{
convert::Infallible, fmt, future::ready, future::Future, net::SocketAddr, path::PathBuf, pin::Pin, sync::Arc,
};
Expand Down Expand Up @@ -230,8 +231,7 @@ where
pub async fn register(self) -> Result<RegisteredExtension<E>, Error> {
let client = &Client::builder().build()?;

let extension_id = register(client, self.extension_name, self.events).await?;
let extension_id = extension_id.to_str()?;
let register_res = register(client, self.extension_name, self.events).await?;

// Logs API subscriptions must be requested during the Lambda init phase (see
// https://docs.aws.amazon.com/lambda/latest/dg/runtimes-logs-api.html#runtimes-logs-api-subscribing).
Expand Down Expand Up @@ -266,7 +266,7 @@ where
// Call Logs API to start receiving events
let req = requests::subscribe_request(
Api::LogsApi,
extension_id,
&register_res.extension_id,
self.log_types,
self.log_buffering,
self.log_port_number,
Expand Down Expand Up @@ -312,7 +312,7 @@ where
// Call Telemetry API to start receiving events
let req = requests::subscribe_request(
Api::TelemetryApi,
extension_id,
&register_res.extension_id,
self.telemetry_types,
self.telemetry_buffering,
self.telemetry_port_number,
Expand All @@ -326,7 +326,11 @@ where
}

Ok(RegisteredExtension {
extension_id: extension_id.to_string(),
extension_id: register_res.extension_id,
function_name: register_res.function_name,
function_version: register_res.function_version,
handler: register_res.handler,
account_id: register_res.account_id,
events_processor: self.events_processor,
})
}
Expand All @@ -339,7 +343,17 @@ where

/// An extension registered by calling [`Extension::register`].
pub struct RegisteredExtension<E> {
extension_id: String,
/// The ID of the registered extension. This ID is unique per extension and remains constant
pub extension_id: String,
/// The ID of the account the extension was registered to.
/// This will be `None` if the register request doesn't send the Lambda-Extension-Accept-Feature header
pub account_id: Option<String>,
/// The name of the Lambda function that the extension is registered with
pub function_name: String,
/// The version of the Lambda function that the extension is registered with
pub function_version: String,
/// The Lambda function handler that AWS Lambda invokes
pub handler: String,
events_processor: E,
}

Expand Down Expand Up @@ -468,12 +482,30 @@ where
}
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct RegisterResponseBody {
function_name: String,
function_version: String,
handler: String,
account_id: Option<String>,
}

#[derive(Debug)]
struct RegisterResponse {
extension_id: String,
function_name: String,
function_version: String,
handler: String,
account_id: Option<String>,
}

/// Initialize and register the extension in the Extensions API
async fn register<'a>(
client: &'a Client,
extension_name: Option<&'a str>,
events: Option<&'a [&'a str]>,
) -> Result<http::HeaderValue, Error> {
) -> Result<RegisterResponse, Error> {
let name = match extension_name {
Some(name) => name.into(),
None => {
Expand Down Expand Up @@ -501,5 +533,17 @@ async fn register<'a>(
.get(requests::EXTENSION_ID_HEADER)
.ok_or_else(|| ExtensionError::boxed("missing extension id header"))
.map_err(|e| ExtensionError::boxed(e.to_string()))?;
Ok(header.clone())
let extension_id = header.to_str()?.to_string();

let (_, body) = res.into_parts();
let body = body.collect().await?.to_bytes();
let response: RegisterResponseBody = serde_json::from_slice(&body)?;

Ok(RegisterResponse {
extension_id,
function_name: response.function_name,
function_version: response.function_version,
handler: response.handler,
account_id: response.account_id,
})
}
6 changes: 6 additions & 0 deletions lambda-extension/src/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ const EXTENSION_ERROR_TYPE_HEADER: &str = "Lambda-Extension-Function-Error-Type"
const CONTENT_TYPE_HEADER_NAME: &str = "Content-Type";
const CONTENT_TYPE_HEADER_VALUE: &str = "application/json";

// Comma separated list of features the extension supports.
// `accountId` is currently the only supported feature.
const EXTENSION_ACCEPT_FEATURE: &str = "Lambda-Extension-Accept-Feature";
const EXTENSION_ACCEPT_FEATURE_VALUE: &str = "accountId";

pub(crate) fn next_event_request(extension_id: &str) -> Result<Request<Body>, Error> {
let req = build_request()
.method(Method::GET)
Expand All @@ -25,6 +30,7 @@ pub(crate) fn register_request(extension_name: &str, events: &[&str]) -> Result<
.method(Method::POST)
.uri("/2020-01-01/extension/register")
.header(EXTENSION_NAME_HEADER, extension_name)
.header(EXTENSION_ACCEPT_FEATURE, EXTENSION_ACCEPT_FEATURE_VALUE)
.header(CONTENT_TYPE_HEADER_NAME, CONTENT_TYPE_HEADER_VALUE)
.body(Body::from(serde_json::to_string(&events)?))?;

Expand Down

0 comments on commit 1cf868c

Please sign in to comment.