-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
main.rs
159 lines (135 loc) · 4.95 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
//! Example chat application.
//!
//! Run with
//!
//! ```not_rust
//! cargo run -p example-chat
//! ```
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
response::{Html, IntoResponse},
routing::get,
Router,
};
use futures::{sink::SinkExt, stream::StreamExt};
use std::{
collections::HashSet,
sync::{Arc, Mutex},
};
use tokio::sync::broadcast;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
// Our shared state. `main` builds a single instance, wraps it in an `Arc`,
// and hands it to the router through `with_state()`.
struct AppState {
// We require unique usernames. This tracks which usernames have been taken:
// `check_username` inserts a name when it is claimed, and `websocket` removes
// it again at the end of that client's connection.
user_set: Mutex<HashSet<String>>,
// Sending half of the broadcast channel used to send messages to all
// connected clients. Each connection calls `subscribe()` on it to obtain its
// own receiver.
tx: broadcast::Sender<String>,
}
// Program entry point: installs the tracing subscriber, builds the shared
// chat state and the router, then serves the app on 127.0.0.1:3000.
#[tokio::main]
async fn main() {
    // Use the filter from the environment when `try_from_default_env`
    // succeeds; otherwise fall back to trace-level output for this example.
    let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| "example_chat=trace".into());
    let fmt_layer = tracing_subscriber::fmt::layer();
    tracing_subscriber::registry()
        .with(env_filter)
        .with(fmt_layer)
        .init();

    // Application state handed to the router via with_state(). The receiver
    // returned alongside the sender is not used here; each websocket
    // connection subscribes to the sender on its own.
    let (tx, _initial_rx) = broadcast::channel(100);
    let shared_state = Arc::new(AppState {
        user_set: Mutex::new(HashSet::new()),
        tx,
    });

    let router = Router::new()
        .route("/", get(index))
        .route("/websocket", get(websocket_handler))
        .with_state(shared_state);

    // Failing to bind or to serve is fatal for this example, hence the unwraps.
    let bind_result = tokio::net::TcpListener::bind("127.0.0.1:3000").await;
    let listener = bind_result.unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());

    let serve_result = axum::serve(listener, router).await;
    serve_result.unwrap();
}
// Handler for the `/websocket` route: requests the protocol upgrade and, once
// it has happened, runs `websocket` with the new socket and the shared state.
async fn websocket_handler(
    ws: WebSocketUpgrade,
    State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
    // The callback owns `state`, so the `Arc` is moved into the connection's
    // future rather than borrowed from this handler.
    ws.on_upgrade(move |socket| async move { websocket(socket, state).await })
}
// This function deals with a single websocket connection, i.e., a single
// connected client / user, for which we will spawn two independent tasks (for
// receiving / sending chat messages).
async fn websocket(stream: WebSocket, state: Arc<AppState>) {
// By splitting, we can send and receive at the same time.
let (mut sender, mut receiver) = stream.split();
// Username gets set in the receive loop, if it's valid.
let mut username = String::new();
// Loop until a text message is found.
while let Some(Ok(message)) = receiver.next().await {
if let Message::Text(name) = message {
// If username that is sent by client is not taken, fill username string.
check_username(&state, &mut username, &name);
// If not empty we want to quit the loop else we want to quit function.
if !username.is_empty() {
break;
} else {
// Only send our client that username is taken.
let _ = sender
.send(Message::Text(String::from("Username already taken.")))
.await;
return;
}
}
}
// We subscribe *before* sending the "joined" message, so that we will also
// display it to our client.
let mut rx = state.tx.subscribe();
// Now send the "joined" message to all subscribers.
let msg = format!("{username} joined.");
tracing::debug!("{msg}");
let _ = state.tx.send(msg);
// Spawn the first task that will receive broadcast messages and send text
// messages over the websocket to our client.
let mut send_task = tokio::spawn(async move {
while let Ok(msg) = rx.recv().await {
// In any websocket error, break loop.
if sender.send(Message::Text(msg)).await.is_err() {
break;
}
}
});
// Clone things we want to pass (move) to the receiving task.
let tx = state.tx.clone();
let name = username.clone();
// Spawn a task that takes messages from the websocket, prepends the user
// name, and sends them to all broadcast subscribers.
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(Message::Text(text))) = receiver.next().await {
// Add username before message.
let _ = tx.send(format!("{name}: {text}"));
}
});
// If any one of the tasks run to completion, we abort the other.
tokio::select! {
_ = (&mut send_task) => recv_task.abort(),
_ = (&mut recv_task) => send_task.abort(),
};
// Send "user left" message (similar to "joined" above).
let msg = format!("{username} left.");
tracing::debug!("{msg}");
let _ = state.tx.send(msg);
// Remove username from map so new clients can take it again.
state.user_set.lock().unwrap().remove(&username);
}
// Tries to claim `name` as a username.
//
// On success the name is recorded in `state.user_set` and appended to
// `string`. If the name is empty or already in use, neither the set nor
// `string` is modified, so a caller that passes in an empty `string` can treat
// "still empty" as "rejected".
//
// Panics if the `user_set` mutex is poisoned.
fn check_username(state: &AppState, string: &mut String, name: &str) {
    // An empty name would leave `string` empty, which the caller reads as a
    // rejection. Storing it anyway would leave an entry in the set that the
    // rejected caller never releases, so refuse it up front.
    if name.is_empty() {
        return;
    }

    let mut user_set = state.user_set.lock().unwrap();
    // `insert` returns `true` only when the value was not already present,
    // so the check and the claim happen in a single lookup under the lock.
    if user_set.insert(name.to_owned()) {
        string.push_str(name);
    }
}
// Handler for `/`: responds with the chat page. The utf-8 file is embedded
// into the binary at **compile** time.
async fn index() -> Html<&'static str> {
    const CHAT_PAGE: &str = include_str!("../chat.html");
    Html(CHAT_PAGE)
}