diff --git a/tokio-postgres/src/copy_both.rs b/tokio-postgres/src/copy_both.rs index d3b46eab..e8c2103b 100644 --- a/tokio-postgres/src/copy_both.rs +++ b/tokio-postgres/src/copy_both.rs @@ -16,6 +16,9 @@ use std::task::{Context, Poll}; /// The state machine of CopyBothReceiver /// /// ```ignore +/// Setup +/// | +/// v /// CopyBoth /// / \ /// v v @@ -189,17 +192,27 @@ impl Stream for CopyBothReceiver { Poll::Pending => match self.state { Setup | CopyBoth | CopyIn => match ready!(self.sink_receiver.poll_next_unpin(cx)) { Some(msg) => Poll::Ready(Some(msg)), - None => { - self.state = match self.state { - CopyBoth => CopyOut, - CopyIn => CopyNone, - _ => unreachable!(), - }; - - let mut buf = BytesMut::new(); - frontend::copy_done(&mut buf); - Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) - } + None => match self.state { + // The user has cancelled their interest to this CopyBoth query but we're + // still in the Setup phase. From this point the receiver will either enter + // CopyBoth mode or will receive an Error response from PostgreSQL. When + // either of those happens the state machine will terminate the connection + // appropriately. + Setup => Poll::Pending, + CopyBoth => { + self.state = CopyOut; + let mut buf = BytesMut::new(); + frontend::copy_done(&mut buf); + Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) + } + CopyIn => { + self.state = CopyNone; + let mut buf = BytesMut::new(); + frontend::copy_done(&mut buf); + Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) + } + _ => unreachable!(), + }, }, _ => Poll::Pending, }, diff --git a/tokio-postgres/tests/test/copy_both.rs b/tokio-postgres/tests/test/copy_both.rs index 2723928a..445d93c0 100644 --- a/tokio-postgres/tests/test/copy_both.rs +++ b/tokio-postgres/tests/test/copy_both.rs @@ -1,6 +1,8 @@ use futures_util::{future, StreamExt, TryStreamExt}; use tokio_postgres::{error::SqlState, Client, SimpleQueryMessage, SimpleQueryRow}; +use crate::Cancellable; + async fn q(client: &Client, query: &str) -> Vec { let msgs = client.simple_query(query).await.unwrap(); @@ -123,3 +125,36 @@ async fn copy_both() { // Ensure we can continue issuing queries assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1")); } + +#[tokio::test] +async fn copy_both_future_cancellation() { + let client = crate::connect("user=postgres replication=database").await; + + let slot_query = + "CREATE_REPLICATION_SLOT future_cancellation TEMPORARY LOGICAL \"test_decoding\""; + let lsn = q(&client, slot_query).await[0] + .get("consistent_point") + .unwrap() + .to_owned(); + + let query = format!("START_REPLICATION SLOT future_cancellation LOGICAL {}", lsn); + for i in 0.. { + let done = { + let duplex_stream = client.copy_both_simple::(&query); + let fut = Cancellable { + fut: duplex_stream, + polls_left: i, + }; + fut.await + .map(|res| res.expect("copy_both failed")) + .is_some() + }; + + // Ensure we can continue issuing queries + assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1")); + + if done { + break; + } + } +}