Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for WebSockets over HTTP/2 #2894

Merged
merged 6 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ features = [
[dev-dependencies]
anyhow = "1.0"
axum-macros = { path = "../axum-macros", version = "0.4.1", features = ["__private"] }
hyper = { version = "1.1.0", features = ["client"] }
quickcheck = "1.0"
quickcheck_macros = "1.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] }
Expand Down
207 changes: 156 additions & 51 deletions axum/src/extract/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
//! ```
//! use axum::{
//! extract::ws::{WebSocketUpgrade, WebSocket},
//! routing::get,
//! routing::any,
//! response::{IntoResponse, Response},
//! Router,
//! };
//!
//! let app = Router::new().route("/ws", get(handler));
//! let app = Router::new().route("/ws", any(handler));
//!
//! async fn handler(ws: WebSocketUpgrade) -> Response {
//! ws.on_upgrade(handle_socket)
Expand Down Expand Up @@ -40,7 +40,7 @@
//! use axum::{
//! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
//! response::Response,
//! routing::get,
//! routing::any,
//! Router,
//! };
//!
Expand All @@ -58,7 +58,7 @@
//! }
//!
//! let app = Router::new()
//! .route("/ws", get(handler))
//! .route("/ws", any(handler))
//! .with_state(AppState { /* ... */ });
//! # let _: Router = app;
//! ```
Expand Down Expand Up @@ -102,7 +102,7 @@ use futures_util::{
use http::{
header::{self, HeaderMap, HeaderName, HeaderValue},
request::Parts,
Method, StatusCode,
Method, StatusCode, Version,
};
use hyper_util::rt::TokioIo;
use sha1::{Digest, Sha1};
Expand All @@ -122,17 +122,21 @@ use tokio_tungstenite::{

/// Extractor for establishing WebSocket connections.
///
/// Note: This extractor requires the request method to be `GET` so it should
/// always be used with [`get`](crate::routing::get). Requests with other methods will be
/// rejected.
/// For HTTP/1.1 requests, this extractor requires the request method to be `GET`;
/// in later versions, `CONNECT` is used instead. Thus it should either be used
/// with [`any`](crate::routing::any), or placed behind
/// [`on`](crate::routing::on)`(`[`MethodFilter`]`::GET.or(`[`MethodFilter`]`::POST), ...)`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we're using POST here because we don't support CONNECT directly?
Do you think that we should add CONNECT as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, that’s just a typo on my part…

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although I do now reälize that MethodFilter::CONNECT doesn’t exist. I pushed a commit to consistently use any instead.

///
/// See the [module docs](self) for an example.
///
/// [`MethodFilter`]: crate::routing::MethodFilter
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
config: WebSocketConfig,
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
protocol: Option<HeaderValue>,
sec_websocket_key: HeaderValue,
/// `None` if HTTP/2+ WebSockets are used.
sec_websocket_key: Option<HeaderValue>,
on_upgrade: hyper::upgrade::OnUpgrade,
on_failed_upgrade: F,
sec_websocket_protocol: Option<HeaderValue>,
Expand Down Expand Up @@ -330,25 +334,34 @@ impl<F> WebSocketUpgrade<F> {
callback(socket).await;
});

#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
#[allow(clippy::declare_interior_mutable_const)]
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");

let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, WEBSOCKET)
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(self.sec_websocket_key.as_bytes()),
);

if let Some(protocol) = self.protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
if let Some(sec_websocket_key) = &self.sec_websocket_key {
// If `sec_websocket_key` was `Some`, we are using HTTP/1.1.

#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
#[allow(clippy::declare_interior_mutable_const)]
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");

let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, WEBSOCKET)
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(sec_websocket_key.as_bytes()),
);

if let Some(protocol) = self.protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}

builder.body(Body::empty()).unwrap()
builder.body(Body::empty()).unwrap()
} else {
// Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond
// with a 2XX with an empty body:
// <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
Response::new(Body::empty())
}
}
}

Expand Down Expand Up @@ -389,28 +402,49 @@ where
type Rejection = WebSocketUpgradeRejection;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if parts.method != Method::GET {
return Err(MethodNotGet.into());
}
let sec_websocket_key = if parts.version <= Version::HTTP_11 {
if parts.method != Method::GET {
return Err(MethodNotGet.into());
}

if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into());
}
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into());
}

if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into());
}
if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into());
}

Some(
parts
.headers
.get(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?
.clone(),
)
} else {
if parts.method != Method::CONNECT {
return Err(MethodNotConnect.into());
}

// if this feature flag is disabled, we won’t be receiving an HTTP/2 request to begin
// with.
#[cfg(feature = "http2")]
if parts
.extensions
.get::<hyper::ext::Protocol>()
.map_or(true, |p| p.as_str() != "websocket")
{
return Err(InvalidProtocolPseudoheader.into());
}

None
};

if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
return Err(InvalidWebSocketVersionHeader.into());
}

let sec_websocket_key = parts
.headers
.get(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?
.clone();

let on_upgrade = parts
.extensions
.remove::<hyper::upgrade::OnUpgrade>()
Expand Down Expand Up @@ -708,6 +742,13 @@ pub mod rejection {
pub struct MethodNotGet;
}

define_rejection! {
#[status = METHOD_NOT_ALLOWED]
#[body = "Request method must be `CONNECT`"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct MethodNotConnect;
}

define_rejection! {
#[status = BAD_REQUEST]
#[body = "Connection header did not include 'upgrade'"]
Expand All @@ -722,6 +763,13 @@ pub mod rejection {
pub struct InvalidUpgradeHeader;
}

define_rejection! {
#[status = BAD_REQUEST]
#[body = "`:protocol` pseudo-header did not include 'websocket'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidProtocolPseudoheader;
}

define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Sec-WebSocket-Version` header did not include '13'"]
Expand Down Expand Up @@ -757,8 +805,10 @@ pub mod rejection {
/// extractor can fail.
pub enum WebSocketUpgradeRejection {
MethodNotGet,
MethodNotConnect,
InvalidConnectionHeader,
InvalidUpgradeHeader,
InvalidProtocolPseudoheader,
InvalidWebSocketVersionHeader,
WebSocketKeyHeaderMissing,
ConnectionNotUpgradable,
Expand Down Expand Up @@ -833,8 +883,16 @@ mod tests {
use std::future::ready;

use super::*;
use crate::{routing::get, test_helpers::spawn_service, Router};
use crate::{
routing::{any, get},
test_helpers::spawn_service,
Router,
};
use http::{Request, Version};
use http_body_util::BodyExt as _;
use hyper_util::rt::TokioExecutor;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite;
use tower::ServiceExt;

Expand Down Expand Up @@ -883,11 +941,56 @@ mod tests {

#[crate::test]
async fn integration_test() {
let app = Router::new().route(
"/echo",
get(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
);
let addr = spawn_service(echo_app());
let (socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
.await
.unwrap();
test_echo_app(socket).await;
}

#[crate::test]
#[cfg(feature = "http2")]
async fn http2() {
let addr = spawn_service(echo_app());
let io = TokioIo::new(TcpStream::connect(addr).await.unwrap());
let (mut send_request, conn) =
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.handshake(io)
.await
.unwrap();

// Wait a little for the SETTINGS frame to go through…
for _ in 0..10 {
tokio::task::yield_now().await;
}
assert!(conn.is_extended_connect_protocol_enabled());
tokio::spawn(async {
conn.await.unwrap();
});

let req = Request::builder()
.method(Method::CONNECT)
.extension(hyper::ext::Protocol::from_static("websocket"))
.uri("/echo")
.header("sec-websocket-version", "13")
.header("Host", "server.example.com")
.body(Body::empty())
.unwrap();

let response = send_request.send_request(req).await.unwrap();
let status = response.status();
if status != 200 {
let body = response.into_body().collect().await.unwrap().to_bytes();
let body = std::str::from_utf8(&body).unwrap();
panic!("response status was {}: {body}", status);
}
let upgraded = hyper::upgrade::on(response).await.unwrap();
let upgraded = TokioIo::new(upgraded);
let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
test_echo_app(socket).await;
}

fn echo_app() -> Router {
async fn handle_socket(mut socket: WebSocket) {
while let Some(Ok(msg)) = socket.recv().await {
match msg {
Expand All @@ -903,11 +1006,13 @@ mod tests {
}
}

let addr = spawn_service(app);
let (mut socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
.await
.unwrap();
Router::new().route(
"/echo",
any(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
)
}

async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(mut socket: WebSocketStream<S>) {
let input = tungstenite::Message::Text("foobar".to_owned());
socket.send(input.clone()).await.unwrap();
let output = socket.next().await.unwrap().unwrap();
Expand Down
8 changes: 3 additions & 5 deletions axum/src/routing/method_routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1034,13 +1034,11 @@ where
match $svc {
MethodEndpoint::None => {}
MethodEndpoint::Route(route) => {
return RouteFuture::from_future(route.clone().oneshot_inner($req))
.strip_body($method == Method::HEAD);
return route.clone().oneshot_inner($req);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that this is the same as #2897.
Do you think we should first merge #2897 to reduce the scope of this PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both PRs benefit from this change, so it can be merged in either order!

}
MethodEndpoint::BoxedHandler(handler) => {
let route = handler.clone().into_route(state);
return RouteFuture::from_future(route.clone().oneshot_inner($req))
.strip_body($method == Method::HEAD);
let mut route = handler.clone().into_route(state);
return route.oneshot_inner($req);
}
}
}
Expand Down
6 changes: 2 additions & 4 deletions axum/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -658,12 +658,10 @@ where

fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<E> {
match self {
Fallback::Default(route) | Fallback::Service(route) => {
RouteFuture::from_future(route.oneshot_inner(req))
}
Fallback::Default(route) | Fallback::Service(route) => route.oneshot_inner(req),
Fallback::BoxedHandler(handler) => {
let mut route = handler.clone().into_route(state);
RouteFuture::from_future(route.oneshot_inner(req))
route.oneshot_inner(req)
}
}
}
Expand Down
Loading