Skip to content

Commit

Permalink
Add support for WebSockets over HTTP/2 (#2894)
Browse files Browse the repository at this point in the history
  • Loading branch information
SabrinaJewson authored Oct 6, 2024
1 parent d783a8b commit 64e6eda
Show file tree
Hide file tree
Showing 14 changed files with 375 additions and 87 deletions.
2 changes: 2 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **breaking:** Upgrade matchit to 0.8, changing the path parameter syntax from `/:single` and `/*many`
to `/{single}` and `/{*many}`; the old syntax produces a panic to avoid silent change in behavior ([#2645])
- **change:** Update minimum rust version to 1.75 ([#2943])
- **added:** Add support WebSockets over HTTP/2.
They can be enabled by changing `get(ws_endpoint)` handlers to `any(ws_endpoint)`.

[#2473]: https://github.com/tokio-rs/axum/pull/2473
[#2645]: https://github.com/tokio-rs/axum/pull/2645
Expand Down
1 change: 1 addition & 0 deletions axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ features = [
[dev-dependencies]
anyhow = "1.0"
axum-macros = { path = "../axum-macros", features = ["__private"] }
hyper = { version = "1.1.0", features = ["client"] }
quickcheck = "1.0"
quickcheck_macros = "1.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] }
Expand Down
212 changes: 156 additions & 56 deletions axum/src/extract/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
//! ```
//! use axum::{
//! extract::ws::{WebSocketUpgrade, WebSocket},
//! routing::get,
//! routing::any,
//! response::{IntoResponse, Response},
//! Router,
//! };
//!
//! let app = Router::new().route("/ws", get(handler));
//! let app = Router::new().route("/ws", any(handler));
//!
//! async fn handler(ws: WebSocketUpgrade) -> Response {
//! ws.on_upgrade(handle_socket)
Expand Down Expand Up @@ -40,7 +40,7 @@
//! use axum::{
//! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
//! response::Response,
//! routing::get,
//! routing::any,
//! Router,
//! };
//!
Expand All @@ -58,7 +58,7 @@
//! }
//!
//! let app = Router::new()
//! .route("/ws", get(handler))
//! .route("/ws", any(handler))
//! .with_state(AppState { /* ... */ });
//! # let _: Router = app;
//! ```
Expand Down Expand Up @@ -101,7 +101,7 @@ use futures_util::{
use http::{
header::{self, HeaderMap, HeaderName, HeaderValue},
request::Parts,
Method, StatusCode,
Method, StatusCode, Version,
};
use hyper_util::rt::TokioIo;
use sha1::{Digest, Sha1};
Expand All @@ -121,17 +121,20 @@ use tokio_tungstenite::{

/// Extractor for establishing WebSocket connections.
///
/// Note: This extractor requires the request method to be `GET` so it should
/// always be used with [`get`](crate::routing::get). Requests with other methods will be
/// rejected.
/// For HTTP/1.1 requests, this extractor requires the request method to be `GET`;
/// in later versions, `CONNECT` is used instead.
/// To support both, it should be used with [`any`](crate::routing::any).
///
/// See the [module docs](self) for an example.
///
/// [`MethodFilter`]: crate::routing::MethodFilter
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
config: WebSocketConfig,
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
protocol: Option<HeaderValue>,
sec_websocket_key: HeaderValue,
/// `None` if HTTP/2+ WebSockets are used.
sec_websocket_key: Option<HeaderValue>,
on_upgrade: hyper::upgrade::OnUpgrade,
on_failed_upgrade: F,
sec_websocket_protocol: Option<HeaderValue>,
Expand Down Expand Up @@ -212,12 +215,12 @@ impl<F> WebSocketUpgrade<F> {
/// ```
/// use axum::{
/// extract::ws::{WebSocketUpgrade, WebSocket},
/// routing::get,
/// routing::any,
/// response::{IntoResponse, Response},
/// Router,
/// };
///
/// let app = Router::new().route("/ws", get(handler));
/// let app = Router::new().route("/ws", any(handler));
///
/// async fn handler(ws: WebSocketUpgrade) -> Response {
/// ws.protocols(["graphql-ws", "graphql-transport-ws"])
Expand Down Expand Up @@ -329,25 +332,34 @@ impl<F> WebSocketUpgrade<F> {
callback(socket).await;
});

#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
#[allow(clippy::declare_interior_mutable_const)]
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");

let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, WEBSOCKET)
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(self.sec_websocket_key.as_bytes()),
);

if let Some(protocol) = self.protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
if let Some(sec_websocket_key) = &self.sec_websocket_key {
// If `sec_websocket_key` was `Some`, we are using HTTP/1.1.

#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
#[allow(clippy::declare_interior_mutable_const)]
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");

let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, WEBSOCKET)
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(sec_websocket_key.as_bytes()),
);

if let Some(protocol) = self.protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}

builder.body(Body::empty()).unwrap()
builder.body(Body::empty()).unwrap()
} else {
// Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond
// with a 2XX with an empty body:
// <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
Response::new(Body::empty())
}
}
}

Expand Down Expand Up @@ -387,28 +399,49 @@ where
type Rejection = WebSocketUpgradeRejection;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if parts.method != Method::GET {
return Err(MethodNotGet.into());
}
let sec_websocket_key = if parts.version <= Version::HTTP_11 {
if parts.method != Method::GET {
return Err(MethodNotGet.into());
}

if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into());
}
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into());
}

if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into());
}
if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into());
}

Some(
parts
.headers
.get(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?
.clone(),
)
} else {
if parts.method != Method::CONNECT {
return Err(MethodNotConnect.into());
}

// if this feature flag is disabled, we won’t be receiving an HTTP/2 request to begin
// with.
#[cfg(feature = "http2")]
if parts
.extensions
.get::<hyper::ext::Protocol>()
.map_or(true, |p| p.as_str() != "websocket")
{
return Err(InvalidProtocolPseudoheader.into());
}

None
};

if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
return Err(InvalidWebSocketVersionHeader.into());
}

let sec_websocket_key = parts
.headers
.get(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?
.clone();

let on_upgrade = parts
.extensions
.remove::<hyper::upgrade::OnUpgrade>()
Expand Down Expand Up @@ -706,6 +739,13 @@ pub mod rejection {
pub struct MethodNotGet;
}

define_rejection! {
#[status = METHOD_NOT_ALLOWED]
#[body = "Request method must be `CONNECT`"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct MethodNotConnect;
}

define_rejection! {
#[status = BAD_REQUEST]
#[body = "Connection header did not include 'upgrade'"]
Expand All @@ -720,6 +760,13 @@ pub mod rejection {
pub struct InvalidUpgradeHeader;
}

define_rejection! {
#[status = BAD_REQUEST]
#[body = "`:protocol` pseudo-header did not include 'websocket'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidProtocolPseudoheader;
}

define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Sec-WebSocket-Version` header did not include '13'"]
Expand Down Expand Up @@ -755,8 +802,10 @@ pub mod rejection {
/// extractor can fail.
pub enum WebSocketUpgradeRejection {
MethodNotGet,
MethodNotConnect,
InvalidConnectionHeader,
InvalidUpgradeHeader,
InvalidProtocolPseudoheader,
InvalidWebSocketVersionHeader,
WebSocketKeyHeaderMissing,
ConnectionNotUpgradable,
Expand Down Expand Up @@ -838,14 +887,18 @@ mod tests {
use std::future::ready;

use super::*;
use crate::{routing::get, test_helpers::spawn_service, Router};
use crate::{routing::any, test_helpers::spawn_service, Router};
use http::{Request, Version};
use http_body_util::BodyExt as _;
use hyper_util::rt::TokioExecutor;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite;
use tower::ServiceExt;

#[crate::test]
async fn rejects_http_1_0_requests() {
let svc = get(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
let svc = any(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
let rejection = ws.unwrap_err();
assert!(matches!(
rejection,
Expand Down Expand Up @@ -874,7 +927,7 @@ mod tests {
async fn handler(ws: WebSocketUpgrade) -> Response {
ws.on_upgrade(|_| async {})
}
let _: Router = Router::new().route("/", get(handler));
let _: Router = Router::new().route("/", any(handler));
}

#[allow(dead_code)]
Expand All @@ -883,16 +936,61 @@ mod tests {
ws.on_failed_upgrade(|_error: Error| println!("oops!"))
.on_upgrade(|_| async {})
}
let _: Router = Router::new().route("/", get(handler));
let _: Router = Router::new().route("/", any(handler));
}

#[crate::test]
async fn integration_test() {
let app = Router::new().route(
"/echo",
get(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
);
let addr = spawn_service(echo_app());
let (socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
.await
.unwrap();
test_echo_app(socket).await;
}

#[crate::test]
#[cfg(feature = "http2")]
async fn http2() {
let addr = spawn_service(echo_app());
let io = TokioIo::new(TcpStream::connect(addr).await.unwrap());
let (mut send_request, conn) =
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.handshake(io)
.await
.unwrap();

// Wait a little for the SETTINGS frame to go through…
for _ in 0..10 {
tokio::task::yield_now().await;
}
assert!(conn.is_extended_connect_protocol_enabled());
tokio::spawn(async {
conn.await.unwrap();
});

let req = Request::builder()
.method(Method::CONNECT)
.extension(hyper::ext::Protocol::from_static("websocket"))
.uri("/echo")
.header("sec-websocket-version", "13")
.header("Host", "server.example.com")
.body(Body::empty())
.unwrap();

let response = send_request.send_request(req).await.unwrap();
let status = response.status();
if status != 200 {
let body = response.into_body().collect().await.unwrap().to_bytes();
let body = std::str::from_utf8(&body).unwrap();
panic!("response status was {}: {body}", status);
}
let upgraded = hyper::upgrade::on(response).await.unwrap();
let upgraded = TokioIo::new(upgraded);
let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
test_echo_app(socket).await;
}

fn echo_app() -> Router {
async fn handle_socket(mut socket: WebSocket) {
while let Some(Ok(msg)) = socket.recv().await {
match msg {
Expand All @@ -908,11 +1006,13 @@ mod tests {
}
}

let addr = spawn_service(app);
let (mut socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
.await
.unwrap();
Router::new().route(
"/echo",
any(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
)
}

async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(mut socket: WebSocketStream<S>) {
let input = tungstenite::Message::Text("foobar".to_owned());
socket.send(input.clone()).await.unwrap();
let output = socket.next().await.unwrap().unwrap();
Expand Down
8 changes: 3 additions & 5 deletions axum/src/routing/method_routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1035,13 +1035,11 @@ where
match $svc {
MethodEndpoint::None => {}
MethodEndpoint::Route(route) => {
return RouteFuture::from_future(route.clone().oneshot_inner($req))
.strip_body($method == Method::HEAD);
return route.clone().oneshot_inner($req);
}
MethodEndpoint::BoxedHandler(handler) => {
let route = handler.clone().into_route(state);
return RouteFuture::from_future(route.clone().oneshot_inner($req))
.strip_body($method == Method::HEAD);
let mut route = handler.clone().into_route(state);
return route.oneshot_inner($req);
}
}
}
Expand Down
Loading

0 comments on commit 64e6eda

Please sign in to comment.