Skip to content

Commit

Permalink
HintWriter & HintReader
Browse files Browse the repository at this point in the history
  • Loading branch information
clabby committed Sep 23, 2023
1 parent 2d5bd28 commit c682fa7
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 36 deletions.
128 changes: 97 additions & 31 deletions crates/preimage/src/hints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,32 @@ pub struct HintWriter {
tx: Sender<Vec<u8>>,
}

unsafe impl Send for HintWriter {}
unsafe impl Sync for HintWriter {}

impl HintWriter {
fn new(rx: Receiver<Vec<u8>>, tx: Sender<Vec<u8>>) -> Self {
Self { rx, tx }
}
}

impl Hinter for HintWriter {
fn hint<T: Hint>(&self, value: T) -> Result<bool> {
fn hint<T: Hint>(&self, value: T) -> Result<()> {
let hint = value.hint();
let mut hint_bytes = vec![0u8; 4 + hint.len()];
hint_bytes[0..4].copy_from_slice((hint.len() as u32).to_be_bytes().as_ref());
hint_bytes[4..].copy_from_slice(hint);

self.tx.send(hint_bytes)?;

match self.rx.recv() {
Ok(n) => {
if n.len() != 1 {
anyhow::bail!(
"Failed to read invalid pre-image hint ack, received response: {:?}",
n
);
}
Ok(true)
}
Err(e) => Ok(false),
let n = self.rx.recv()?;
if n.len() != 1 {
anyhow::bail!(
"Failed to read invalid pre-image hint ack, received response: {:?}",
n
);
}
Ok(())
}
}

Expand All @@ -48,6 +47,9 @@ pub struct HintReader {
tx: Sender<Vec<u8>>,
}

unsafe impl Send for HintReader {}
unsafe impl Sync for HintReader {}

impl HintReader {
fn new(rx: Receiver<Vec<u8>>, tx: Sender<Vec<u8>>) -> Self {
Self { rx, tx }
Expand All @@ -56,15 +58,17 @@ impl HintReader {

impl HintReader {
pub fn next_hint(&self, router: HintHandler) -> Result<bool> {
let raw_len = self.rx.recv()?;
if raw_len.len() != 4 {
let raw_payload = self.rx.recv()?;
if raw_payload.len() < 4 {
// Return EOF
return Ok(true);
}
let length = u32::from_be_bytes(raw_len.as_slice().try_into()?) as usize;

let length = u32::from_be_bytes(raw_payload.as_slice()[0..4].try_into()?) as usize;
let payload = if length == 0 {
Vec::default()
} else {
self.rx.recv()?
raw_payload[4..].try_into()?
};

if let Err(e) = router(&payload) {
Expand All @@ -75,7 +79,7 @@ impl HintReader {

// write back to unblock the hint writer after routing the hint we received.
self.tx.send(vec![0])?;
Ok(true)
Ok(false)
}
}

Expand All @@ -84,37 +88,38 @@ mod tests {
use super::*;
use std::sync::{
atomic::{AtomicU32, Ordering},
Arc,
mpsc, Arc,
};

async fn test_hint(hints: Vec<Vec<u8>>) {
let (bw, ar) = std::sync::mpsc::channel::<Vec<u8>>();
let (aw, br) = std::sync::mpsc::channel::<Vec<u8>>();

let hint_writer = Arc::new(HintWriter::new(ar, aw));
let hint_reader = Arc::new(HintReader::new(br, bw));

let counter_written = Arc::new(AtomicU32::new(0));
let counter_received = Arc::new(AtomicU32::new(0));

let hints_a = Arc::new(hints.clone());
let counter_w = Arc::clone(&counter_written);
let (hints_a, counter_w) = (Arc::new(hints.clone()), Arc::clone(&counter_written));
let a = tokio::spawn(async move {
let hint_writer = HintWriter::new(ar, aw);
let cw = Arc::clone(&counter_w);
for hint in hints_a.iter() {
cw.fetch_add(1, Ordering::SeqCst);
counter_w.fetch_add(1, Ordering::SeqCst);
hint_writer.hint(hint).unwrap();
}
});

let hints_b = Arc::new(hints.clone());
let counter_r = Arc::clone(&counter_received);
let (reader, hints_b, counter_r) = (
Arc::clone(&hint_reader),
Arc::new(hints.clone()),
Arc::clone(&counter_received),
);
let b = tokio::spawn(async move {
let hint_reader = HintReader::new(br, bw);
for i in 0..hints_b.len() {
let counter = Arc::clone(&counter_r);
let Ok(eof) = hint_reader.next_hint(Box::new(move |hint| {
let counter_r = Arc::clone(&counter_r);
let Ok(eof) = reader.next_hint(Box::new(move |hint| {
// Increase the number of hint requests received.
counter.fetch_add(1, Ordering::SeqCst);
dbg!("yo");
counter_r.fetch_add(1, Ordering::SeqCst);
Ok(())
})) else {
panic!("Failed to read hint {}", i);
Expand All @@ -141,7 +146,68 @@ mod tests {

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn hello_world_hint() {
test_hint(vec![b"asd".to_vec()]).await;
test_hint(vec![b"hello world".to_vec()]).await;
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn zero_byte() {
test_hint(vec![vec![0]]).await;
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn many_zeros() {
test_hint(vec![vec![0; 1000]]).await;
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn rand_bytes() {
use rand::RngCore;

let mut rand = [0u8; 2048];
rand::thread_rng().fill_bytes(&mut rand);
test_hint(vec![rand.to_vec()]).await;
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn multiple_hints() {
test_hint(vec![
b"hello world".to_vec(),
b"cannon cannon cannon".to_vec(),
b"".to_vec(),
b"milady".to_vec(),
])
.await;
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn cb_error() {
let (aw, br) = mpsc::channel();
let (bw, ar) = mpsc::channel();

let hint_writer = Arc::new(HintWriter::new(ar, aw));
let hint_reader = Arc::new(HintReader::new(br, bw));

let writer = Arc::clone(&hint_writer);
let a = tokio::spawn(async move {
writer.hint(b"one".to_vec().as_ref()).unwrap();
writer.hint(b"two".to_vec().as_ref()).unwrap();
});

let reader = Arc::clone(&hint_reader);
let b = tokio::spawn(async move {
let Err(_) = reader.next_hint(Box::new(|hint| {
anyhow::bail!("cb_error");
})) else {
panic!("Failed to read hint");
};

reader
.next_hint(Box::new(|hint| {
assert_eq!(hint, b"two");
Ok(())
}))
.unwrap();
});
}

impl Hint for String {
Expand Down
10 changes: 5 additions & 5 deletions crates/preimage/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ pub trait Hint {
pub trait Hinter {
/// Sends a hint to the host.
///
/// ### Takes
/// - `hint` - The hint to send to the host.
///
/// ### Returns
/// - `Ok(true)` if the hint was sent successfully.
/// - `Ok(false)` if the hint was not sent successfully due to the host being
/// closed.
/// - `Err(e)` if the hint was not sent successfully due to an error.
fn hint<T: Hint>(&self, hint: T) -> Result<bool>;
/// - A [Result] indicating whether or not the hint was successfully sent.
fn hint<T: Hint>(&self, hint: T) -> Result<()>;
}

0 comments on commit c682fa7

Please sign in to comment.