Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an example integration test for SSE #2465

Merged
merged 7 commits into from
Dec 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/sse/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,8 @@ tokio-stream = "0.1"
tower-http = { version = "0.5.0", features = ["fs", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

[dev-dependencies]
eventsource-stream = "0.2"
reqwest = { version = "0.11", features = ["stream"] }
reqwest-eventsource = "0.5"
80 changes: 71 additions & 9 deletions examples/sse/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
//! ```not_rust
//! cargo run -p example-sse
//! ```
//! Test with
//! ```not_rust
//! cargo test -p example-sse
//! ```

use axum::{
response::sse::{Event, Sse},
Expand All @@ -26,15 +30,8 @@ async fn main() {
.with(tracing_subscriber::fmt::layer())
.init();

let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets");

let static_files_service = ServeDir::new(assets_dir).append_index_html_on_directories(true);

// build our application with a route
let app = Router::new()
.fallback_service(static_files_service)
.route("/sse", get(sse_handler))
.layer(TraceLayer::new_for_http());
// build our application
let app = app();

// run it
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
Expand All @@ -44,6 +41,16 @@ async fn main() {
axum::serve(listener, app).await.unwrap();
}

fn app() -> Router {
let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets");
let static_files_service = ServeDir::new(assets_dir).append_index_html_on_directories(true);
// build our application with a route
Router::new()
.fallback_service(static_files_service)
.route("/sse", get(sse_handler))
.layer(TraceLayer::new_for_http())
}

async fn sse_handler(
TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
Expand All @@ -63,3 +70,58 @@ async fn sse_handler(
.text("keep-alive-text"),
)
}

#[cfg(test)]
mod tests {
use eventsource_stream::Eventsource;
use tokio::net::TcpListener;

use super::*;

#[tokio::test]
async fn integration_test() {
// A helper function that spawns our application in the background
async fn spawn_app(host: impl Into<String>) -> String {
let host = host.into();
// Bind to localhost at the port 0, which will let the OS assign an available port to us
let listener = TcpListener::bind(format!("{}:0", host)).await.unwrap();
// Retrieve the port assigned to us by the OS
let port = listener.local_addr().unwrap().port();
tokio::spawn(async {
axum::serve(listener, app()).await.unwrap();
});
// Returns address (e.g. http://127.0.0.1{random_port})
format!("http://{}:{}", host, port)
}
let listening_url = spawn_app("127.0.0.1").await;

let mut event_stream = reqwest::Client::new()
.get(&format!("{}/sse", listening_url))
.header("User-Agent", "integration_test")
.send()
.await
.unwrap()
.bytes_stream()
.eventsource()
.take(1);

let mut event_data: Vec<String> = vec![];
while let Some(event) = event_stream.next().await {
match event {
Ok(event) => {
// break the loop at the end of SSE stream
if event.data == "[DONE]" {
break;
}

event_data.push(event.data);
}
Err(_) => {
panic!("Error in event stream");
}
}
}

assert!(event_data[0] == "hi!");
}
}