diff --git a/crates/preimage/src/hints.rs b/crates/preimage/src/hints.rs index c67509e..ceee05f 100644 --- a/crates/preimage/src/hints.rs +++ b/crates/preimage/src/hints.rs @@ -11,6 +11,9 @@ pub struct HintWriter { tx: Sender>, } +unsafe impl Send for HintWriter {} +unsafe impl Sync for HintWriter {} + impl HintWriter { fn new(rx: Receiver>, tx: Sender>) -> Self { Self { rx, tx } @@ -18,7 +21,7 @@ impl HintWriter { } impl Hinter for HintWriter { - fn hint(&self, value: T) -> Result { + fn hint(&self, value: T) -> Result<()> { let hint = value.hint(); let mut hint_bytes = vec![0u8; 4 + hint.len()]; hint_bytes[0..4].copy_from_slice((hint.len() as u32).to_be_bytes().as_ref()); @@ -26,18 +29,14 @@ impl Hinter for HintWriter { self.tx.send(hint_bytes)?; - match self.rx.recv() { - Ok(n) => { - if n.len() != 1 { - anyhow::bail!( - "Failed to read invalid pre-image hint ack, received response: {:?}", - n - ); - } - Ok(true) - } - Err(e) => Ok(false), + let n = self.rx.recv()?; + if n.len() != 1 { + anyhow::bail!( + "Failed to read invalid pre-image hint ack, received response: {:?}", + n + ); } + Ok(()) } } @@ -48,6 +47,9 @@ pub struct HintReader { tx: Sender>, } +unsafe impl Send for HintReader {} +unsafe impl Sync for HintReader {} + impl HintReader { fn new(rx: Receiver>, tx: Sender>) -> Self { Self { rx, tx } @@ -56,15 +58,17 @@ impl HintReader { impl HintReader { pub fn next_hint(&self, router: HintHandler) -> Result { - let raw_len = self.rx.recv()?; - if raw_len.len() != 4 { + let raw_payload = self.rx.recv()?; + if raw_payload.len() < 4 { + // Return EOF return Ok(true); } - let length = u32::from_be_bytes(raw_len.as_slice().try_into()?) as usize; + + let length = u32::from_be_bytes(raw_payload.as_slice()[0..4].try_into()?) as usize; let payload = if length == 0 { Vec::default() } else { - self.rx.recv()? + raw_payload[4..].try_into()? }; if let Err(e) = router(&payload) { @@ -75,7 +79,7 @@ impl HintReader { // write back to unblock the hint writer after routing the hint we received. self.tx.send(vec![0])?; - Ok(true) + Ok(false) } } @@ -84,37 +88,38 @@ mod tests { use super::*; use std::sync::{ atomic::{AtomicU32, Ordering}, - Arc, + mpsc, Arc, }; async fn test_hint(hints: Vec>) { let (bw, ar) = std::sync::mpsc::channel::>(); let (aw, br) = std::sync::mpsc::channel::>(); + let hint_writer = Arc::new(HintWriter::new(ar, aw)); + let hint_reader = Arc::new(HintReader::new(br, bw)); + let counter_written = Arc::new(AtomicU32::new(0)); let counter_received = Arc::new(AtomicU32::new(0)); - let hints_a = Arc::new(hints.clone()); - let counter_w = Arc::clone(&counter_written); + let (hints_a, counter_w) = (Arc::new(hints.clone()), Arc::clone(&counter_written)); let a = tokio::spawn(async move { - let hint_writer = HintWriter::new(ar, aw); - let cw = Arc::clone(&counter_w); for hint in hints_a.iter() { - cw.fetch_add(1, Ordering::SeqCst); + counter_w.fetch_add(1, Ordering::SeqCst); hint_writer.hint(hint).unwrap(); } }); - let hints_b = Arc::new(hints.clone()); - let counter_r = Arc::clone(&counter_received); + let (reader, hints_b, counter_r) = ( + Arc::clone(&hint_reader), + Arc::new(hints.clone()), + Arc::clone(&counter_received), + ); let b = tokio::spawn(async move { - let hint_reader = HintReader::new(br, bw); for i in 0..hints_b.len() { - let counter = Arc::clone(&counter_r); - let Ok(eof) = hint_reader.next_hint(Box::new(move |hint| { + let counter_r = Arc::clone(&counter_r); + let Ok(eof) = reader.next_hint(Box::new(move |hint| { // Increase the number of hint requests received. - counter.fetch_add(1, Ordering::SeqCst); - dbg!("yo"); + counter_r.fetch_add(1, Ordering::SeqCst); Ok(()) })) else { panic!("Failed to read hint {}", i); @@ -141,7 +146,68 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn hello_world_hint() { - test_hint(vec![b"asd".to_vec()]).await; + test_hint(vec![b"hello world".to_vec()]).await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn zero_byte() { + test_hint(vec![vec![0]]).await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn many_zeros() { + test_hint(vec![vec![0; 1000]]).await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn rand_bytes() { + use rand::RngCore; + + let mut rand = [0u8; 2048]; + rand::thread_rng().fill_bytes(&mut rand); + test_hint(vec![rand.to_vec()]).await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn multiple_hints() { + test_hint(vec![ + b"hello world".to_vec(), + b"cannon cannon cannon".to_vec(), + b"".to_vec(), + b"milady".to_vec(), + ]) + .await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn cb_error() { + let (aw, br) = mpsc::channel(); + let (bw, ar) = mpsc::channel(); + + let hint_writer = Arc::new(HintWriter::new(ar, aw)); + let hint_reader = Arc::new(HintReader::new(br, bw)); + + let writer = Arc::clone(&hint_writer); + let a = tokio::spawn(async move { + writer.hint(b"one".to_vec().as_ref()).unwrap(); + writer.hint(b"two".to_vec().as_ref()).unwrap(); + }); + + let reader = Arc::clone(&hint_reader); + let b = tokio::spawn(async move { + let Err(_) = reader.next_hint(Box::new(|hint| { + anyhow::bail!("cb_error"); + })) else { + panic!("Failed to read hint"); + }; + + reader + .next_hint(Box::new(|hint| { + assert_eq!(hint, b"two"); + Ok(()) + })) + .unwrap(); + }); } impl Hint for String { diff --git a/crates/preimage/src/traits.rs b/crates/preimage/src/traits.rs index f538405..dfac1a9 100644 --- a/crates/preimage/src/traits.rs +++ b/crates/preimage/src/traits.rs @@ -31,10 +31,10 @@ pub trait Hint { pub trait Hinter { /// Sends a hint to the host. /// + /// ### Takes + /// - `hint` - The hint to send to the host. + /// /// ### Returns - /// - `Ok(true)` if the hint was sent successfully. - /// - `Ok(false)` if the hint was not sent successfully due to the host being - /// closed. - /// - `Err(e)` if the hint was not sent successfully due to an error. - fn hint(&self, hint: T) -> Result; + /// - A [Result] indicating whether or not the hint was successfully sent. + fn hint(&self, hint: T) -> Result<()>; }