Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add impl Sink for WebSocketConnection #9

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ async-std = "1.7.0"
[dev-dependencies]
env_logger = "0.8.2"
async-std = { version = "1.7.0", features = ["attributes"] }
futures = "0.3.12"
191 changes: 191 additions & 0 deletions examples/futures_select.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
use async_std::channel::{unbounded, Sender};
use futures::{SinkExt, StreamExt};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tide::http::mime;
use tide::log::*;
use tide::Response;
use tide_websockets::{Message, WebSocket};

struct State {
connections: HashMap<u32, Sender<Message>>,
tick_count: u32,
next_id: u32,
}
impl State {
fn add_connection(&mut self, connection_tx: Sender<Message>) {
self.connections.insert(self.next_id, connection_tx);
self.next_id += 1;
}
}
#[async_std::main]
async fn main() -> Result<(), std::io::Error> {
env_logger::init();

let state = Arc::new(Mutex::new(State {
connections: HashMap::new(),
tick_count: 0,
next_id: 0,
}));

let mut app = tide::with_state(Arc::clone(&state));

async_std::task::spawn(app_loop(state));

app.at("/").get(|_| async move {
Ok(Response::builder(200)
.body(INDEX_HTML)
.content_type(mime::HTML)
.build())
});

app.at("/ws").get(WebSocket::new(
|request: tide::Request<Arc<Mutex<State>>>, stream| async move {
info!("new websocket opened");
let state = request.state();

// Channel for outgoing WebSocket messages from other threads
let (send_ws_msg_tx, send_ws_msg_rx) = unbounded::<Message>();
let mut send_ws_msg_rx = send_ws_msg_rx.fuse();

state.lock().unwrap().add_connection(send_ws_msg_tx);

let mut stream = stream.fuse();
loop {
let ws_msg = futures::select! {

ws_msg = stream.select_next_some() => {
match ws_msg {
// _ => None,
Ok(Message::Close(_)) => {
println!("peer disconnected");
break
},
Ok(Message::Ping(data)) => Some(Message::Pong(data)),
Ok(Message::Pong(_)) => None,
Ok(Message::Binary(_)) => None,
Ok(Message::Text(text)) => {
println!("Message received: {}",text);
None
},
Err(_) =>{
// done
None
}
}
},

// Handle WebSocket messages we created asynchronously
// to send them out now
ws_msg = send_ws_msg_rx.select_next_some() => Some(ws_msg),

// Once we're done, break the loop and return
complete => break,

};

// If there's a message to send out, do so now
if let Some(ws_msg) = ws_msg {
stream.send(ws_msg).await?;
}
}

// TODO: socket closed or errored, delete from remote handles

Ok(())
},
));

app.listen("127.0.0.1:8080").await?;

Ok(())
}

async fn app_loop(state: Arc<Mutex<State>>) -> () {
loop {
let (tick_count, connections) = {
let mut state = state.lock().unwrap();
println!(
"Tick #{}. Connection count is {}",
state.tick_count,
state.connections.len()
);
let tick_count = state.tick_count;
state.tick_count += 1;

let connections = state.connections.clone();
(tick_count, connections)
};
let connection_count = connections.len();

for conn in connections {
match conn
.1
.send(Message::Text(format!(
r#"{{ "ticks": {}, "connections": {} }}"#,
tick_count, connection_count
)))
.await
{
Ok(_) => (),
Err(_e) => {
//println!("error: {}", e);
state.lock().unwrap().connections.remove(&conn.0);
}
}
}
std::thread::sleep(std::time::Duration::from_millis(1000));
}
}

const INDEX_HTML: &str = r##"
<!DOCTYPE html>

<html>

<head>
<title>futures::select!() example</title>
<script type="text/javascript">
var count;
var tick;
var websocketConnection;

function onServerMessage(event) {
var msg;

try {
msg = JSON.parse(event.data);
} catch (e) {
console.log("bad: " + e);
return;
}

count.innerText = msg.connections;
tick.innerText = msg.ticks;

websocketConnection.send("Hello there.");
}

window.onload = function () {
count = document.getElementById("count");
tick = document.getElementById("tick");
var wsHost = window.location.hostname;
var wsPort = window.location.port;
if (wsPort)
wsPort = ":" + wsPort;
var wsUrl = "ws://" + wsHost + wsPort + "/ws";
console.log("ws url:" + wsUrl);

websocketConnection = new WebSocket(wsUrl);
websocketConnection.addEventListener("message", onServerMessage);
};

</script>
</head>

<body>
connection count: <span id="count">?</span> tick: <span id="tick">?</span>
</body>

</html>
"##;
34 changes: 33 additions & 1 deletion src/websocket_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use std::pin::Pin;
use async_dup::{Arc, Mutex};
use async_std::task;
use async_tungstenite::WebSocketStream;
use futures_util::stream::{SplitSink, SplitStream, Stream};
use futures_util::{
sink::Sink,
stream::{SplitSink, SplitStream, Stream},
};
use futures_util::{SinkExt, StreamExt};

use crate::Message;
Expand Down Expand Up @@ -55,6 +58,35 @@ impl Stream for WebSocketConnection {
}
}

impl Sink<Message> for WebSocketConnection {
type Error = async_tungstenite::tungstenite::Error;

fn poll_ready(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> task::Poll<Result<(), Self::Error>> {
Pin::new(&mut *self.0.lock()).poll_ready(cx)
}

fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
Pin::new(&mut *self.0.lock()).start_send(item)
}

fn poll_flush(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> task::Poll<Result<(), Self::Error>> {
Pin::new(&mut *self.0.lock()).poll_flush(cx)
}

fn poll_close(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> task::Poll<Result<(), Self::Error>> {
Pin::new(&mut *self.0.lock()).poll_close(cx)
}
}

impl From<WebSocketStream<Connection>> for WebSocketConnection {
fn from(ws: WebSocketStream<Connection>) -> Self {
Self::new(ws)
Expand Down