From b3b29859e741653697f940593cf743bd9c7b1e52 Mon Sep 17 00:00:00 2001 From: Zeke Mostov Date: Thu, 4 Aug 2022 15:29:01 -0400 Subject: [PATCH 1/9] Impl futures::Stream for Listener --- Cargo.lock | 29 +++++++++++++++++++++++++++++ qos-core/Cargo.toml | 1 + qos-core/src/io/stream.rs | 13 +++++++++++++ 3 files changed, 43 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 1c73f46f..d2824af6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -219,6 +219,20 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f73fe65f54d1e12b726f517d3e2135ca3125a437b6d998caf1962961f7172d9e" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.21" @@ -226,6 +240,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3083ce4b914124575708913bca19bfe887522d6e2e6d0952943f5eac4a74010" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -234,6 +249,18 @@ version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3" +[[package]] +name = "futures-io" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc4045962a5a5e935ee2fdedaa4e08284547402885ab326734432bed5d12966b" + +[[package]] +name = "futures-sink" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21163e139fa306126e6eedaf49ecdb4588f939600f0b1e770f4205ee4b7fa868" + [[package]] name = "futures-task" version = "0.3.21" @@ -247,6 +274,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a" dependencies = [ "futures-core", + "futures-sink", "futures-task", "pin-project-lite", "pin-utils", @@ -614,6 +642,7 @@ version = "0.1.0" dependencies = [ "aws-nitro-enclaves-nsm-api", "borsh", + "futures", "nix 0.24.1", "openssl", "qos-crypto", diff --git a/qos-core/Cargo.toml b/qos-core/Cargo.toml index 002ce47f..f111907b 100644 --- a/qos-core/Cargo.toml +++ b/qos-core/Cargo.toml @@ -9,6 +9,7 @@ qos-crypto = { path = "../qos-crypto" } nix = { version = "0.24.1", features = ["socket"], default-features = false } openssl = { version = "0.10.40", default-features = false } borsh = { version = "0.9" } +futures = { version = "0.3", default-features = false } # For AWS Nitro aws-nitro-enclaves-nsm-api = { version = "0.2.1", default-features = false } diff --git a/qos-core/src/io/stream.rs b/qos-core/src/io/stream.rs index 6765e1b9..5dca9ea4 100644 --- a/qos-core/src/io/stream.rs +++ b/qos-core/src/io/stream.rs @@ -12,6 +12,9 @@ use nix::{ }, unistd::close, }; +use core::task::Poll; +use core::task::Context; +use core::pin::Pin; use super::IOError; @@ -243,6 +246,16 @@ impl Iterator for Listener { } } +impl futures::stream::Stream for Listener { + type Item = Stream; + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_> + ) -> Poll> { + Poll::Ready(self.accept().ok()) + } +} + impl Drop for Listener { fn drop(&mut self) { // Its ok if either of these error - likely means the other end of the From fb71aed56dafacee845b0100ccf827b24d63626f Mon Sep 17 00:00:00 2001 From: Zeke Mostov Date: Thu, 4 Aug 2022 17:50:26 -0400 Subject: [PATCH 2/9] Make state sync and compile --- Cargo.lock | 26 +++++++++++ qos-core/Cargo.toml | 2 +- qos-core/src/client.rs | 2 +- qos-core/src/protocol/attestor/mock.rs | 1 + qos-core/src/protocol/attestor/mod.rs | 2 +- qos-core/src/protocol/mod.rs | 43 ++++++++++++------ qos-core/src/protocol/services/boot.rs | 2 +- qos-core/src/protocol/services/provision.rs | 11 ++--- qos-core/src/server.rs | 48 ++++++++++++++++++--- 9 files changed, 109 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d2824af6..19dff7aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -227,6 +227,7 @@ checksum = "f73fe65f54d1e12b726f517d3e2135ca3125a437b6d998caf1962961f7172d9e" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -249,6 +250,18 @@ version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3" +[[package]] +name = "futures-executor" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9420b90cfa29e327d0429f19be13e7ddb68fa1cccb09d65e5706b8c7a749b8a6" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", + "num_cpus", +] + [[package]] name = "futures-io" version = "0.3.21" @@ -273,11 +286,15 @@ version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a" dependencies = [ + "futures-channel", "futures-core", + "futures-io", "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -807,6 +824,15 @@ dependencies = [ "serde", ] +[[package]] +name = "slab" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4614a76b2a8be0058caa9dbbaf66d988527d86d003c11a94fbd335d7661edcef" +dependencies = [ + "autocfg", +] + [[package]] name = "socket2" version = "0.4.4" diff --git a/qos-core/Cargo.toml b/qos-core/Cargo.toml index f111907b..1dc8ea32 100644 --- a/qos-core/Cargo.toml +++ b/qos-core/Cargo.toml @@ -9,7 +9,7 @@ qos-crypto = { path = "../qos-crypto" } nix = { version = "0.24.1", features = ["socket"], default-features = false } openssl = { version = "0.10.40", default-features = false } borsh = { version = "0.9" } -futures = { version = "0.3", default-features = false } +futures = { version = "0.3", default-features = false, features = ["thread-pool"] } # For AWS Nitro aws-nitro-enclaves-nsm-api = { version = "0.2.1", default-features = false } diff --git a/qos-core/src/client.rs b/qos-core/src/client.rs index 300c7afa..c803829e 100644 --- a/qos-core/src/client.rs +++ b/qos-core/src/client.rs @@ -24,7 +24,7 @@ impl From for ClientError { } /// Client for communicating with the enclave [`crate::server::SocketServer`]. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Client { addr: SocketAddress, } diff --git a/qos-core/src/protocol/attestor/mock.rs b/qos-core/src/protocol/attestor/mock.rs index 3b1de9ad..ecf87f78 100644 --- a/qos-core/src/protocol/attestor/mock.rs +++ b/qos-core/src/protocol/attestor/mock.rs @@ -31,6 +31,7 @@ pub const MOCK_NSM_ATTESTATION_DOCUMENT: &[u8] = include_bytes!("./static/mock_attestation_doc"); /// Mock Nitro Secure Module endpoint that should only ever be used for testing. +#[derive(Clone)] pub struct MockNsm; impl NsmProvider for MockNsm { fn nsm_process_request( diff --git a/qos-core/src/protocol/attestor/mod.rs b/qos-core/src/protocol/attestor/mod.rs index 08052992..c9530b45 100644 --- a/qos-core/src/protocol/attestor/mod.rs +++ b/qos-core/src/protocol/attestor/mod.rs @@ -11,7 +11,7 @@ pub mod types; /// generic so mock providers can be subbed in for testing. In production use /// [`Nsm`]. // https://github.com/aws/aws-nitro-enclaves-nsm-api/blob/main/docs/attestation_process.md -pub trait NsmProvider { +pub trait NsmProvider: Send { /// Create a message with input data and output capacity from a given /// request, then send it to the NSM driver via `ioctl()` and wait /// for the driver's response. diff --git a/qos-core/src/protocol/mod.rs b/qos-core/src/protocol/mod.rs index bf22118d..c240757f 100644 --- a/qos-core/src/protocol/mod.rs +++ b/qos-core/src/protocol/mod.rs @@ -1,5 +1,7 @@ //! Quorum protocol state machine. +use std::sync::{Arc, RwLock}; + use borsh::{BorshDeserialize, BorshSerialize}; use qos_crypto::sha_256; @@ -102,6 +104,7 @@ pub enum ProtocolError { /// Payload is too big. See `MAX_ENCODED_MSG_LEN` for the upper bound on /// message size. OversizedPayload, + InUnrecoverablePhase, } impl From for ProtocolError { @@ -130,7 +133,7 @@ impl From for ProtocolError { /// Protocol executor state. #[derive( - Debug, PartialEq, Clone, borsh::BorshSerialize, borsh::BorshDeserialize, + Debug, PartialEq, Copy, Clone, borsh::BorshSerialize, borsh::BorshDeserialize, )] pub enum ProtocolPhase { /// The state machine cannot recover. The enclave must be rebooted. @@ -143,13 +146,29 @@ pub enum ProtocolPhase { QuorumKeyProvisioned, } +impl ProtocolPhase { + /// Try to update the `current` phase to the `target` phase. If the current phase is not updatable, + /// the phase will not be updated. + /// + /// Returns a copy of the new current phase, which will match `target` if `current` was updated. + fn update(current: &Arc>, target: ProtocolPhase) -> Result<(), ProtocolError> { + let mut current = current.write().unwrap(); + if *current == ProtocolPhase::UnrecoverableError { + Err(ProtocolError::InUnrecoverablePhase) + } else { + *current = target; + Ok(()) + } + } +} + /// Enclave executor state // TODO only include mutables in here, all else should be written to file as // read only pub struct ProtocolState { - provisioner: services::provision::SecretBuilder, + provisioner: Arc>, attestor: Box, - phase: ProtocolPhase, + phase: Arc>, handles: Handles, app_client: Client, } @@ -163,8 +182,8 @@ impl ProtocolState { let provisioner = services::provision::SecretBuilder::new(); Self { attestor, - provisioner, - phase: ProtocolPhase::WaitingForBootInstruction, + provisioner: Arc::new(RwLock::new(provisioner)), + phase: Arc::new(RwLock::new(ProtocolPhase::WaitingForBootInstruction)), handles, app_client: Client::new(app_addr), } @@ -189,7 +208,7 @@ impl Executor { } fn routes(&self) -> Vec> { - match self.state.phase { + match *self.state.phase.read().unwrap() { ProtocolPhase::UnrecoverableError => { vec![Box::new(handlers::status)] } @@ -252,7 +271,7 @@ impl server::RequestProcessor for Executor { } } - let err = ProtocolError::NoMatchingRoute(self.state.phase.clone()); + let err = ProtocolError::NoMatchingRoute(*self.state.phase.read().unwrap()); ProtocolMsg::ProtocolErrorResponse(err) .try_to_vec() .expect("ProtocolMsg can always be serialized. qed.") @@ -274,7 +293,7 @@ mod handlers { state: &mut ProtocolState, ) -> Option { if let ProtocolMsg::StatusRequest = req { - Some(ProtocolMsg::StatusResponse(state.phase.clone())) + Some(ProtocolMsg::StatusResponse(*state.phase.read().unwrap())) } else { None } @@ -324,7 +343,7 @@ mod handlers { Some(ProtocolMsg::ProvisionResponse { reconstructed }) } Err(e) => { - state.phase = ProtocolPhase::UnrecoverableError; + *state.phase.write().unwrap() = ProtocolPhase::UnrecoverableError; Some(ProtocolMsg::ProtocolErrorResponse(e)) } } @@ -346,7 +365,7 @@ mod handlers { Some(ProtocolMsg::BootStandardResponse { nsm_response }) } Err(e) => { - state.phase = ProtocolPhase::UnrecoverableError; + *state.phase.write().unwrap() = ProtocolPhase::UnrecoverableError; Some(ProtocolMsg::ProtocolErrorResponse(e)) } } @@ -368,7 +387,7 @@ mod handlers { }) } Err(e) => { - state.phase = ProtocolPhase::UnrecoverableError; + *state.phase.write().unwrap() = ProtocolPhase::UnrecoverableError; Some(ProtocolMsg::ProtocolErrorResponse(e)) } } @@ -389,7 +408,7 @@ mod handlers { }) } Err(e) => { - state.phase = ProtocolPhase::UnrecoverableError; + *state.phase.write().unwrap() = ProtocolPhase::UnrecoverableError; Some(ProtocolMsg::ProtocolErrorResponse(e)) } } diff --git a/qos-core/src/protocol/services/boot.rs b/qos-core/src/protocol/services/boot.rs index 0333323a..f31a3bc5 100644 --- a/qos-core/src/protocol/services/boot.rs +++ b/qos-core/src/protocol/services/boot.rs @@ -212,7 +212,7 @@ pub(in crate::protocol) fn boot_standard( manifest_envelope.manifest.qos_hash().to_vec(), ); - state.phase = ProtocolPhase::WaitingForQuorumShards; + ProtocolPhase::update(&state.phase, ProtocolPhase::WaitingForQuorumShards)?; Ok(nsm_response) } diff --git a/qos-core/src/protocol/services/provision.rs b/qos-core/src/protocol/services/provision.rs index 5b9129dc..5fe24ca9 100644 --- a/qos-core/src/protocol/services/provision.rs +++ b/qos-core/src/protocol/services/provision.rs @@ -60,16 +60,17 @@ pub(in crate::protocol) fn provision( .envelope_decrypt(encrypted_share) .map_err(|_| ProtocolError::DecryptionFailed)?; - state.provisioner.add_share(share)?; + let mut provisioner = state.provisioner.write().unwrap(); + provisioner.add_share(share)?; let manifest = state.handles.get_manifest_envelope()?.manifest; let quorum_threshold = manifest.quorum_set.threshold as usize; - if state.provisioner.count() < quorum_threshold { + if provisioner.count() < quorum_threshold { // Nothing else to do if we don't have the threshold to reconstruct return Ok(false); } - let private_key_der = state.provisioner.build()?; + let private_key_der = provisioner.build()?; let pair = qos_crypto::RsaPair::from_der(&private_key_der) .map_err(|_| ProtocolError::InvalidPrivateKey)?; let public_key_der = pair.public_key_to_der()?; @@ -78,13 +79,13 @@ pub(in crate::protocol) fn provision( // We did not construct the intended key // Something went wrong, so clear the existing shares just to be // careful. - state.provisioner.clear(); + provisioner.clear(); return Err(ProtocolError::ReconstructionError); } state.handles.put_quorum_key(&pair)?; - state.phase = ProtocolPhase::QuorumKeyProvisioned; + ProtocolPhase::update(&state.phase, ProtocolPhase::QuorumKeyProvisioned)?; Ok(true) } diff --git a/qos-core/src/server.rs b/qos-core/src/server.rs index 72fec215..1563408a 100644 --- a/qos-core/src/server.rs +++ b/qos-core/src/server.rs @@ -1,7 +1,8 @@ //! Streaming socket based server for use in an enclave. Listens for connections //! from [`crate::client::Client`]. -use std::marker::PhantomData; +use std::{marker::PhantomData}; + use crate::io::{self, Listener, SocketAddress}; @@ -42,17 +43,50 @@ impl SocketServer { println!("`SocketServer` listening on {:?}", addr); let listener = Listener::listen(addr)?; + // let threads = Vec::new(); + // let proccesor_locked = Arc::new(Mutex::new(processor)); + + // futures::executor::block_on(listener.for_each_concurrent(None, move|stream| { + // match stream.recv() { + // Ok(payload) => { + // let response = proccesor_locked.clone().lock().unwrap().process(payload); + // // let _ = stream.send(&response); + // } + // Err(err) => eprintln!("Server::listen error: {:?}", err), + // } + // })); for stream in listener { - match stream.recv() { - Ok(payload) => { - let response = processor.process(payload); - let _ = stream.send(&response); + match stream.recv() { + Ok(payload) => { + let response = processor.process(payload); + let _ = stream.send(&response); + } + Err(err) => eprintln!("Server::listen error: {:?}", err), } - Err(err) => eprintln!("Server::listen error: {:?}", err), - } } + // for stream in listener { + // // TODO: wait if threads are maxed out + // let processor2 = processor.clone(); + // let thread = std::thread::spawn(move || { + // match stream.recv() { + // Ok(payload) => { + // let response = processor.process(payload); + // let _ = stream.send(&response); + // } + // Err(err) => eprintln!("Server::listen error: {:?}", err), + // } + // } + // ); + // threads.push(thread); + + // } + + // for thread in threads { + // drop(thread.join()); + // } + Ok(()) } } From 209ec2f74e0c03a23ba642661a860ca79a641dab Mon Sep 17 00:00:00 2001 From: Zeke Mostov Date: Fri, 5 Aug 2022 11:26:35 -0400 Subject: [PATCH 3/9] Initial thread pool impl --- qos-core/src/coordinator.rs | 4 +- qos-core/src/io/mod.rs | 1 + qos-core/src/io/stream.rs | 9 +-- qos-core/src/io/threadpool.rs | 127 ++++++++++++++++++++++++++++++++++ qos-core/src/protocol/mod.rs | 38 +++++++--- qos-core/src/server.rs | 70 ++++++++----------- 6 files changed, 189 insertions(+), 60 deletions(-) create mode 100644 qos-core/src/io/threadpool.rs diff --git a/qos-core/src/coordinator.rs b/qos-core/src/coordinator.rs index 7a993c43..e7bea5e9 100644 --- a/qos-core/src/coordinator.rs +++ b/qos-core/src/coordinator.rs @@ -4,7 +4,7 @@ //! //! The pivot is an executable the enclave runs to initialize the secure //! applications. -use std::process::Command; +use std::{process::Command, sync::Arc}; use crate::{ handles::Handles, @@ -36,7 +36,7 @@ impl Coordinator { let handles2 = handles.clone(); std::thread::spawn(move || { let executor = Executor::new(nsm, handles2, app_addr); - SocketServer::listen(addr, executor).unwrap(); + SocketServer::listen(addr, Arc::new(executor)).unwrap(); }); loop { diff --git a/qos-core/src/io/mod.rs b/qos-core/src/io/mod.rs index e810e0b7..8c501979 100644 --- a/qos-core/src/io/mod.rs +++ b/qos-core/src/io/mod.rs @@ -4,6 +4,7 @@ //! within this module. mod stream; +pub mod threadpool; pub use stream::SocketAddress; pub(crate) use stream::{Listener, Stream}; diff --git a/qos-core/src/io/stream.rs b/qos-core/src/io/stream.rs index 5dca9ea4..6e074c7c 100644 --- a/qos-core/src/io/stream.rs +++ b/qos-core/src/io/stream.rs @@ -1,5 +1,9 @@ //! Abstractions to handle connection based socket streams. +use core::{ + pin::Pin, + task::{Context, Poll}, +}; use std::{mem::size_of, os::unix::io::RawFd}; #[cfg(feature = "vm")] @@ -12,9 +16,6 @@ use nix::{ }, unistd::close, }; -use core::task::Poll; -use core::task::Context; -use core::pin::Pin; use super::IOError; @@ -250,7 +251,7 @@ impl futures::stream::Stream for Listener { type Item = Stream; fn poll_next( self: Pin<&mut Self>, - _cx: &mut Context<'_> + _cx: &mut Context<'_>, ) -> Poll> { Poll::Ready(self.accept().ok()) } diff --git a/qos-core/src/io/threadpool.rs b/qos-core/src/io/threadpool.rs new file mode 100644 index 00000000..b656f621 --- /dev/null +++ b/qos-core/src/io/threadpool.rs @@ -0,0 +1,127 @@ +use std::{ + sync::{mpsc, Arc, Mutex}, + thread, +}; + +type Job = Box; + +/// Errors for a [`ThreadPool`] +pub enum ThreadPoolError { + MpscSendError(std::sync::mpsc::SendError), +} + +/// An abstraction for executing jobs concurrently across a fixed number of +/// threads. +pub struct ThreadPool { + workers: Vec, + sender: mpsc::Sender, +} + +/// Message sent to a worker thread in the thread pool. +pub enum Message { + NewJob(Job), + Terminate, +} + +impl ThreadPool { + /// Create a new instance of [`Self`]. + /// + /// # Arguments + /// + /// * `size` - Number of threads in pool. + /// + /// # Panics + /// + /// Panics if the `size` is zero. + pub fn new(size: usize) -> ThreadPool { + assert!(size > 0); + + let (sender, receiver) = mpsc::channel(); + + let receiver = Arc::new(Mutex::new(receiver)); + + let mut workers = Vec::with_capacity(size); + + for id in 0..size { + workers.push(Worker::new(id, Arc::clone(&receiver))); + } + + ThreadPool { workers, sender } + } + + /// Execute `f` in the next free thread. This is non blocking. + /// + /// # Errors + /// + /// Returns an error if the `f` could not be sent to a worker thread. + pub fn execute(&self, f: F) -> Result<(), ThreadPoolError> + where + F: FnOnce() + Send + 'static, + { + let job = Box::new(f); + + self.sender + .send(Message::NewJob(job)) + .map_err(|e| ThreadPoolError::MpscSendError(e)); + Ok(()) + } +} + +impl Drop for ThreadPool { + fn drop(&mut self) { + // Send 1 termination signal per worker thread. We don't know exactly + // which worker will recieve each message, but since we know that a + // worker will stop receiving after getting the terminate message, we + // can be confident that non-terminated threads will recieve the + // terminate message exactly once and terminated threads will never + // receive the message. Thus, if we have N workers and send N terminate + // messages we will terminate all worker threads. + for _ in &self.workers { + drop( + self.sender + .send(Message::Terminate) + .map_err(|e| eprintln!("`ThreadPool::drop`: {:?}", e)), + ); + } + + for worker in &mut self.workers { + if let Some(thread) = worker.thread.take() { + drop(thread.join().map_err(|e| { + eprintln!("`ThreadPool::drop: failed to join: {:?}`", e) + })) + } + } + } +} + +struct Worker { + id: usize, + thread: Option>, +} + +impl Worker { + fn new(id: usize, receiver: Arc>>) -> Worker { + let thread = thread::spawn(move || loop { + let message = receiver + .lock() + .expect("channel receiver mutex poisoned") + .recv() + .expect("tried to receive on a closed chanel"); + + match message { + Message::NewJob(job) => { + println!("Worker {} got a job; executing.", id); + + job(); + } + Message::Terminate => { + println!("Worker {} was told to terminate.", id); + + break; + } + } + }); + + Worker { id, thread: Some(thread) } + } +} diff --git a/qos-core/src/protocol/mod.rs b/qos-core/src/protocol/mod.rs index c240757f..281e6576 100644 --- a/qos-core/src/protocol/mod.rs +++ b/qos-core/src/protocol/mod.rs @@ -133,7 +133,12 @@ impl From for ProtocolError { /// Protocol executor state. #[derive( - Debug, PartialEq, Copy, Clone, borsh::BorshSerialize, borsh::BorshDeserialize, + Debug, + PartialEq, + Copy, + Clone, + borsh::BorshSerialize, + borsh::BorshDeserialize, )] pub enum ProtocolPhase { /// The state machine cannot recover. The enclave must be rebooted. @@ -147,11 +152,15 @@ pub enum ProtocolPhase { } impl ProtocolPhase { - /// Try to update the `current` phase to the `target` phase. If the current phase is not updatable, - /// the phase will not be updated. + /// Try to update the `current` phase to the `target` phase. If the current + /// phase is not updatable, the phase will not be updated. /// - /// Returns a copy of the new current phase, which will match `target` if `current` was updated. - fn update(current: &Arc>, target: ProtocolPhase) -> Result<(), ProtocolError> { + /// Returns a copy of the new current phase, which will match `target` if + /// `current` was updated. + fn update( + current: &Arc>, + target: ProtocolPhase, + ) -> Result<(), ProtocolError> { let mut current = current.write().unwrap(); if *current == ProtocolPhase::UnrecoverableError { Err(ProtocolError::InUnrecoverablePhase) @@ -183,7 +192,9 @@ impl ProtocolState { Self { attestor, provisioner: Arc::new(RwLock::new(provisioner)), - phase: Arc::new(RwLock::new(ProtocolPhase::WaitingForBootInstruction)), + phase: Arc::new(RwLock::new( + ProtocolPhase::WaitingForBootInstruction, + )), handles, app_client: Client::new(app_addr), } @@ -271,7 +282,8 @@ impl server::RequestProcessor for Executor { } } - let err = ProtocolError::NoMatchingRoute(*self.state.phase.read().unwrap()); + let err = + ProtocolError::NoMatchingRoute(*self.state.phase.read().unwrap()); ProtocolMsg::ProtocolErrorResponse(err) .try_to_vec() .expect("ProtocolMsg can always be serialized. qed.") @@ -343,7 +355,8 @@ mod handlers { Some(ProtocolMsg::ProvisionResponse { reconstructed }) } Err(e) => { - *state.phase.write().unwrap() = ProtocolPhase::UnrecoverableError; + *state.phase.write().unwrap() = + ProtocolPhase::UnrecoverableError; Some(ProtocolMsg::ProtocolErrorResponse(e)) } } @@ -365,7 +378,8 @@ mod handlers { Some(ProtocolMsg::BootStandardResponse { nsm_response }) } Err(e) => { - *state.phase.write().unwrap() = ProtocolPhase::UnrecoverableError; + *state.phase.write().unwrap() = + ProtocolPhase::UnrecoverableError; Some(ProtocolMsg::ProtocolErrorResponse(e)) } } @@ -387,7 +401,8 @@ mod handlers { }) } Err(e) => { - *state.phase.write().unwrap() = ProtocolPhase::UnrecoverableError; + *state.phase.write().unwrap() = + ProtocolPhase::UnrecoverableError; Some(ProtocolMsg::ProtocolErrorResponse(e)) } } @@ -408,7 +423,8 @@ mod handlers { }) } Err(e) => { - *state.phase.write().unwrap() = ProtocolPhase::UnrecoverableError; + *state.phase.write().unwrap() = + ProtocolPhase::UnrecoverableError; Some(ProtocolMsg::ProtocolErrorResponse(e)) } } diff --git a/qos-core/src/server.rs b/qos-core/src/server.rs index 1563408a..8dc24ae9 100644 --- a/qos-core/src/server.rs +++ b/qos-core/src/server.rs @@ -1,10 +1,11 @@ //! Streaming socket based server for use in an enclave. Listens for connections //! from [`crate::client::Client`]. -use std::{marker::PhantomData}; +use std::{marker::PhantomData, sync::Arc}; +use crate::io::{self, threadpool::ThreadPool, Listener, SocketAddress}; -use crate::io::{self, Listener, SocketAddress}; +const DEFAULT_THREAD_COUNT: usize = 4; /// Error variants for [`SocketServer`] #[derive(Debug)] @@ -38,55 +39,38 @@ impl SocketServer { /// Listen and respond to incoming requests with the given `processor`. pub fn listen( addr: SocketAddress, - mut processor: R, + mut processor: Arc, + thread_count: Option, ) -> Result<(), SocketServerError> { - println!("`SocketServer` listening on {:?}", addr); + let thread_count = thread_count.unwrap_or(DEFAULT_THREAD_COUNT); + println!( + "`SocketServer` listening on {:?} with thread count {thread_count}", + addr + ); let listener = Listener::listen(addr)?; - // let threads = Vec::new(); - // let proccesor_locked = Arc::new(Mutex::new(processor)); - - // futures::executor::block_on(listener.for_each_concurrent(None, move|stream| { - // match stream.recv() { - // Ok(payload) => { - // let response = proccesor_locked.clone().lock().unwrap().process(payload); - // // let _ = stream.send(&response); - // } - // Err(err) => eprintln!("Server::listen error: {:?}", err), - // } - // })); + let thread_pool = ThreadPool::new(thread_count); for stream in listener { - match stream.recv() { - Ok(payload) => { - let response = processor.process(payload); - let _ = stream.send(&response); - } - Err(err) => eprintln!("Server::listen error: {:?}", err), - } - } + let processor2 = processor.clone(); - // for stream in listener { - // // TODO: wait if threads are maxed out - // let processor2 = processor.clone(); - // let thread = std::thread::spawn(move || { - // match stream.recv() { - // Ok(payload) => { - // let response = processor.process(payload); - // let _ = stream.send(&response); - // } - // Err(err) => eprintln!("Server::listen error: {:?}", err), - // } - // } - // ); - // threads.push(thread); + let result = thread_pool + .execute(move || Self::handle_stream(processor2, stream)) + .map_err(|e| eprintln!("`SocketServer::listen`: {:?}", e)); - // } - - // for thread in threads { - // drop(thread.join()); - // } + drop(result) + } Ok(()) } + + fn handle_stream(processor: Arc, stream: io::Stream) { + match stream.recv() { + Ok(payload) => { + let response = processor.process(payload); + let _ = stream.send(&response); + } + Err(err) => eprintln!("Server::listen error: {:?}", err), + } + } } From 58cd54bc4c008bb9d57fe5aff7cf129d28b11c86 Mon Sep 17 00:00:00 2001 From: Zeke Mostov Date: Fri, 5 Aug 2022 12:09:41 -0400 Subject: [PATCH 4/9] Initial async friendly processor / executor --- Cargo.lock | 55 ------------------- qos-core/Cargo.toml | 1 - qos-core/src/coordinator.rs | 6 +- qos-core/src/io/stream.rs | 15 ----- qos-core/src/io/threadpool.rs | 37 +++++++------ qos-core/src/protocol/attestor/mod.rs | 2 +- qos-core/src/protocol/mod.rs | 8 ++- qos-core/src/protocol/services/attestation.rs | 2 +- qos-core/src/protocol/services/boot.rs | 7 ++- qos-core/src/protocol/services/provision.rs | 41 ++++++++++---- qos-core/src/server.rs | 15 ++--- sample-app/src/cli.rs | 2 +- sample-app/src/lib.rs | 1 + 13 files changed, 76 insertions(+), 116 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 19dff7aa..1c73f46f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -219,21 +219,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "futures" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f73fe65f54d1e12b726f517d3e2135ca3125a437b6d998caf1962961f7172d9e" -dependencies = [ - "futures-channel", - "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - [[package]] name = "futures-channel" version = "0.3.21" @@ -241,7 +226,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3083ce4b914124575708913bca19bfe887522d6e2e6d0952943f5eac4a74010" dependencies = [ "futures-core", - "futures-sink", ] [[package]] @@ -250,30 +234,6 @@ version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3" -[[package]] -name = "futures-executor" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9420b90cfa29e327d0429f19be13e7ddb68fa1cccb09d65e5706b8c7a749b8a6" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", - "num_cpus", -] - -[[package]] -name = "futures-io" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc4045962a5a5e935ee2fdedaa4e08284547402885ab326734432bed5d12966b" - -[[package]] -name = "futures-sink" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21163e139fa306126e6eedaf49ecdb4588f939600f0b1e770f4205ee4b7fa868" - [[package]] name = "futures-task" version = "0.3.21" @@ -286,15 +246,10 @@ version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a" dependencies = [ - "futures-channel", "futures-core", - "futures-io", - "futures-sink", "futures-task", - "memchr", "pin-project-lite", "pin-utils", - "slab", ] [[package]] @@ -659,7 +614,6 @@ version = "0.1.0" dependencies = [ "aws-nitro-enclaves-nsm-api", "borsh", - "futures", "nix 0.24.1", "openssl", "qos-crypto", @@ -824,15 +778,6 @@ dependencies = [ "serde", ] -[[package]] -name = "slab" -version = "0.4.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4614a76b2a8be0058caa9dbbaf66d988527d86d003c11a94fbd335d7661edcef" -dependencies = [ - "autocfg", -] - [[package]] name = "socket2" version = "0.4.4" diff --git a/qos-core/Cargo.toml b/qos-core/Cargo.toml index 1dc8ea32..002ce47f 100644 --- a/qos-core/Cargo.toml +++ b/qos-core/Cargo.toml @@ -9,7 +9,6 @@ qos-crypto = { path = "../qos-crypto" } nix = { version = "0.24.1", features = ["socket"], default-features = false } openssl = { version = "0.10.40", default-features = false } borsh = { version = "0.9" } -futures = { version = "0.3", default-features = false, features = ["thread-pool"] } # For AWS Nitro aws-nitro-enclaves-nsm-api = { version = "0.2.1", default-features = false } diff --git a/qos-core/src/coordinator.rs b/qos-core/src/coordinator.rs index e7bea5e9..07252837 100644 --- a/qos-core/src/coordinator.rs +++ b/qos-core/src/coordinator.rs @@ -4,7 +4,7 @@ //! //! The pivot is an executable the enclave runs to initialize the secure //! applications. -use std::{process::Command, sync::Arc}; +use std::process::Command; use crate::{ handles::Handles, @@ -29,14 +29,14 @@ impl Coordinator { /// - If waiting for the pivot errors. pub fn execute( handles: &Handles, - nsm: Box, + nsm: Box, addr: SocketAddress, app_addr: SocketAddress, ) { let handles2 = handles.clone(); std::thread::spawn(move || { let executor = Executor::new(nsm, handles2, app_addr); - SocketServer::listen(addr, Arc::new(executor)).unwrap(); + SocketServer::listen(addr, executor, None).unwrap(); }); loop { diff --git a/qos-core/src/io/stream.rs b/qos-core/src/io/stream.rs index 6e074c7c..ccc87c94 100644 --- a/qos-core/src/io/stream.rs +++ b/qos-core/src/io/stream.rs @@ -1,9 +1,4 @@ //! Abstractions to handle connection based socket streams. - -use core::{ - pin::Pin, - task::{Context, Poll}, -}; use std::{mem::size_of, os::unix::io::RawFd}; #[cfg(feature = "vm")] @@ -247,16 +242,6 @@ impl Iterator for Listener { } } -impl futures::stream::Stream for Listener { - type Item = Stream; - fn poll_next( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(self.accept().ok()) - } -} - impl Drop for Listener { fn drop(&mut self) { // Its ok if either of these error - likely means the other end of the diff --git a/qos-core/src/io/threadpool.rs b/qos-core/src/io/threadpool.rs index b656f621..074806db 100644 --- a/qos-core/src/io/threadpool.rs +++ b/qos-core/src/io/threadpool.rs @@ -1,3 +1,5 @@ +//! Simple thread pool for running concurrent jobs on separate threads. + use std::{ sync::{mpsc, Arc, Mutex}, thread, @@ -6,7 +8,9 @@ use std::{ type Job = Box; /// Errors for a [`ThreadPool`] +#[derive(Debug)] pub enum ThreadPoolError { + /// Wrapper for `std::sync::mpsc::SendError`. MpscSendError(std::sync::mpsc::SendError), } @@ -19,7 +23,9 @@ pub struct ThreadPool { /// Message sent to a worker thread in the thread pool. pub enum Message { + /// Start a new Job. NewJob(Job), + /// Terminate the thread. Terminate, } @@ -33,6 +39,7 @@ impl ThreadPool { /// # Panics /// /// Panics if the `size` is zero. + #[must_use] pub fn new(size: usize) -> ThreadPool { assert!(size > 0); @@ -42,8 +49,8 @@ impl ThreadPool { let mut workers = Vec::with_capacity(size); - for id in 0..size { - workers.push(Worker::new(id, Arc::clone(&receiver))); + for _ in 0..size { + workers.push(Worker::new(Arc::clone(&receiver))); } ThreadPool { workers, sender } @@ -62,7 +69,7 @@ impl ThreadPool { self.sender .send(Message::NewJob(job)) - .map_err(|e| ThreadPoolError::MpscSendError(e)); + .map_err(ThreadPoolError::MpscSendError)?; Ok(()) } } @@ -77,30 +84,28 @@ impl Drop for ThreadPool { // receive the message. Thus, if we have N workers and send N terminate // messages we will terminate all worker threads. for _ in &self.workers { - drop( - self.sender - .send(Message::Terminate) - .map_err(|e| eprintln!("`ThreadPool::drop`: {:?}", e)), - ); + let _ = self + .sender + .send(Message::Terminate) + .map_err(|e| eprintln!("`ThreadPool::drop`: {:?}", e)); } for worker in &mut self.workers { if let Some(thread) = worker.thread.take() { - drop(thread.join().map_err(|e| { - eprintln!("`ThreadPool::drop: failed to join: {:?}`", e) - })) + let _ = thread.join().map_err(|e| { + eprintln!("`ThreadPool::drop: failed to join: {:?}`", e); + }); } } } } struct Worker { - id: usize, thread: Option>, } impl Worker { - fn new(id: usize, receiver: Arc>>) -> Worker { + fn new(receiver: Arc>>) -> Worker { let thread = thread::spawn(move || loop { let message = receiver .lock() @@ -110,18 +115,14 @@ impl Worker { match message { Message::NewJob(job) => { - println!("Worker {} got a job; executing.", id); - job(); } Message::Terminate => { - println!("Worker {} was told to terminate.", id); - break; } } }); - Worker { id, thread: Some(thread) } + Worker { thread: Some(thread) } } } diff --git a/qos-core/src/protocol/attestor/mod.rs b/qos-core/src/protocol/attestor/mod.rs index c9530b45..590d65f4 100644 --- a/qos-core/src/protocol/attestor/mod.rs +++ b/qos-core/src/protocol/attestor/mod.rs @@ -11,7 +11,7 @@ pub mod types; /// generic so mock providers can be subbed in for testing. In production use /// [`Nsm`]. // https://github.com/aws/aws-nitro-enclaves-nsm-api/blob/main/docs/attestation_process.md -pub trait NsmProvider: Send { +pub trait NsmProvider: Send + Sync { /// Create a message with input data and output capacity from a given /// request, then send it to the NSM driver via `ioctl()` and wait /// for the driver's response. diff --git a/qos-core/src/protocol/mod.rs b/qos-core/src/protocol/mod.rs index 281e6576..027d712b 100644 --- a/qos-core/src/protocol/mod.rs +++ b/qos-core/src/protocol/mod.rs @@ -104,6 +104,8 @@ pub enum ProtocolError { /// Payload is too big. See `MAX_ENCODED_MSG_LEN` for the upper bound on /// message size. OversizedPayload, + /// The enclave is in an unrecoverable phase and could not complete the + /// request. InUnrecoverablePhase, } @@ -174,9 +176,10 @@ impl ProtocolPhase { /// Enclave executor state // TODO only include mutables in here, all else should be written to file as // read only +#[derive(Clone)] pub struct ProtocolState { provisioner: Arc>, - attestor: Box, + attestor: Arc>, phase: Arc>, handles: Handles, app_client: Client, @@ -190,7 +193,7 @@ impl ProtocolState { ) -> Self { let provisioner = services::provision::SecretBuilder::new(); Self { - attestor, + attestor: Arc::new(attestor), provisioner: Arc::new(RwLock::new(provisioner)), phase: Arc::new(RwLock::new( ProtocolPhase::WaitingForBootInstruction, @@ -203,6 +206,7 @@ impl ProtocolState { /// Maybe rename state machine? /// Enclave state machine that executes when given a `ProtocolMsg`. +#[derive(Clone)] pub struct Executor { state: ProtocolState, } diff --git a/qos-core/src/protocol/services/attestation.rs b/qos-core/src/protocol/services/attestation.rs index 67605cc7..42d1e5fa 100644 --- a/qos-core/src/protocol/services/attestation.rs +++ b/qos-core/src/protocol/services/attestation.rs @@ -15,7 +15,7 @@ pub(in crate::protocol) fn live_attestation_doc( state.handles.get_manifest_envelope()?.manifest.qos_hash().to_vec(); Ok(get_post_boot_attestation_doc( - &*state.attestor, + &**state.attestor, ephemeral_public_key, manifest_hash, )) diff --git a/qos-core/src/protocol/services/boot.rs b/qos-core/src/protocol/services/boot.rs index f31a3bc5..ded1c96f 100644 --- a/qos-core/src/protocol/services/boot.rs +++ b/qos-core/src/protocol/services/boot.rs @@ -207,7 +207,7 @@ pub(in crate::protocol) fn boot_standard( state.handles.put_manifest_envelope(manifest_envelope)?; let nsm_response = attestation::get_post_boot_attestation_doc( - &*state.attestor, + &**state.attestor, ephemeral_key.public_key_to_pem()?, manifest_envelope.manifest.qos_hash().to_vec(), ); @@ -332,7 +332,10 @@ mod test { std::fs::remove_file(ephemeral_file).unwrap(); std::fs::remove_file(manifest_file).unwrap(); - assert_eq!(protocol_state.phase, ProtocolPhase::WaitingForQuorumShards); + assert_eq!( + *protocol_state.phase.read().unwrap(), + ProtocolPhase::WaitingForQuorumShards + ); } #[test] diff --git a/qos-core/src/protocol/services/provision.rs b/qos-core/src/protocol/services/provision.rs index 5fe24ca9..ae3e36d6 100644 --- a/qos-core/src/protocol/services/provision.rs +++ b/qos-core/src/protocol/services/provision.rs @@ -91,7 +91,10 @@ pub(in crate::protocol) fn provision( #[cfg(test)] mod test { - use std::path::Path; + use std::{ + path::Path, + sync::{Arc, RwLock}, + }; use qos_crypto::{sha_256, shamir::shares_generate, RsaPair}; @@ -156,9 +159,9 @@ mod test { // 3) Create state with eph key and manifest let state = ProtocolState { - provisioner: provision::SecretBuilder::new(), - attestor: Box::new(MockNsm), - phase: ProtocolPhase::WaitingForQuorumShards, + provisioner: Arc::new(RwLock::new(provision::SecretBuilder::new())), + attestor: Arc::new(Box::new(MockNsm)), + phase: Arc::new(RwLock::new(ProtocolPhase::WaitingForQuorumShards)), handles, app_client: Client::new(SocketAddress::new_unix("./never.sock")), }; @@ -188,7 +191,10 @@ mod test { for share in &encrypted_shares[..threshold - 1] { assert_eq!(provision(share, &mut state), Ok(false)); assert!(!Path::new(quorum_file).exists()); - assert_eq!(state.phase, ProtocolPhase::WaitingForQuorumShards); + assert_eq!( + *state.phase.read().unwrap(), + ProtocolPhase::WaitingForQuorumShards + ); } // 6) For shard K, call provision, make sure returns true and writes @@ -197,7 +203,10 @@ mod test { assert_eq!(provision(share, &mut state), Ok(true)); let quorum_key = std::fs::read(quorum_file).unwrap(); assert_eq!(quorum_key, quorum_pair.private_key_to_pem().unwrap()); - assert_eq!(state.phase, ProtocolPhase::QuorumKeyProvisioned); + assert_eq!( + *state.phase.read().unwrap(), + ProtocolPhase::QuorumKeyProvisioned + ); std::fs::remove_file(eph_file).unwrap(); std::fs::remove_file(quorum_file).unwrap(); @@ -227,7 +236,10 @@ mod test { for share in &encrypted_shares[..threshold - 1] { assert_eq!(provision(share, &mut state), Ok(false)); assert!(!Path::new(quorum_file).exists()); - assert_eq!(state.phase, ProtocolPhase::WaitingForQuorumShards); + assert_eq!( + *state.phase.read().unwrap(), + ProtocolPhase::WaitingForQuorumShards + ); } // 6) Add Kth shard of the random key @@ -238,7 +250,10 @@ mod test { ); assert!(!Path::new(quorum_file).exists()); // Note that the handler should set the state to unrecoverable error - assert_eq!(state.phase, ProtocolPhase::WaitingForQuorumShards); + assert_eq!( + *state.phase.read().unwrap(), + ProtocolPhase::WaitingForQuorumShards + ); std::fs::remove_file(eph_file).unwrap(); std::fs::remove_file(manifest_file).unwrap(); @@ -267,7 +282,10 @@ mod test { for share in &encrypted_shares[..threshold - 1] { assert_eq!(provision(share, &mut state), Ok(false)); assert!(!Path::new(quorum_file).exists()); - assert_eq!(state.phase, ProtocolPhase::WaitingForQuorumShards); + assert_eq!( + *state.phase.read().unwrap(), + ProtocolPhase::WaitingForQuorumShards + ); } // 6) Add a bogus shard as the Kth shard @@ -280,7 +298,10 @@ mod test { ); assert!(!Path::new(quorum_file).exists()); // Note that the handler should set the state to unrecoverable error - assert_eq!(state.phase, ProtocolPhase::WaitingForQuorumShards); + assert_eq!( + *state.phase.read().unwrap(), + ProtocolPhase::WaitingForQuorumShards + ); std::fs::remove_file(eph_file).unwrap(); std::fs::remove_file(manifest_file).unwrap(); diff --git a/qos-core/src/server.rs b/qos-core/src/server.rs index 8dc24ae9..ff3e2934 100644 --- a/qos-core/src/server.rs +++ b/qos-core/src/server.rs @@ -1,7 +1,7 @@ //! Streaming socket based server for use in an enclave. Listens for connections //! from [`crate::client::Client`]. -use std::{marker::PhantomData, sync::Arc}; +use std::marker::PhantomData; use crate::io::{self, threadpool::ThreadPool, Listener, SocketAddress}; @@ -35,11 +35,12 @@ pub struct SocketServer { _phantom: PhantomData, } -impl SocketServer { +impl SocketServer { /// Listen and respond to incoming requests with the given `processor`. + #[allow(clippy::needless_pass_by_value)] pub fn listen( addr: SocketAddress, - mut processor: Arc, + processor: R, thread_count: Option, ) -> Result<(), SocketServerError> { let thread_count = thread_count.unwrap_or(DEFAULT_THREAD_COUNT); @@ -54,17 +55,15 @@ impl SocketServer { for stream in listener { let processor2 = processor.clone(); - let result = thread_pool + let _ = thread_pool .execute(move || Self::handle_stream(processor2, stream)) .map_err(|e| eprintln!("`SocketServer::listen`: {:?}", e)); - - drop(result) } Ok(()) } - fn handle_stream(processor: Arc, stream: io::Stream) { + fn handle_stream(mut processor: R, stream: io::Stream) { match stream.recv() { Ok(payload) => { let response = processor.process(payload); @@ -72,5 +71,7 @@ impl SocketServer { } Err(err) => eprintln!("Server::listen error: {:?}", err), } + drop(processor); + drop(stream); } } diff --git a/sample-app/src/cli.rs b/sample-app/src/cli.rs index d06a8e06..541c2ac2 100644 --- a/sample-app/src/cli.rs +++ b/sample-app/src/cli.rs @@ -122,7 +122,7 @@ impl Cli { )); println!("---- Starting secure app server -----"); - SocketServer::listen(opts.addr(), processor).unwrap(); + SocketServer::listen(opts.addr(), processor, Some(4)).unwrap(); } } } diff --git a/sample-app/src/lib.rs b/sample-app/src/lib.rs index f6d25845..145bd715 100644 --- a/sample-app/src/lib.rs +++ b/sample-app/src/lib.rs @@ -87,6 +87,7 @@ pub enum AppMsg { } /// Request router for the app. +#[derive(Clone)] pub struct AppProcessor { handles: Handles, } From 07aeea68de3b53968d8ff437ae8abbfa70e77d77 Mon Sep 17 00:00:00 2001 From: Zeke Mostov Date: Fri, 5 Aug 2022 12:23:12 -0400 Subject: [PATCH 5/9] Add thread count option to enclave server --- qos-core/src/cli.rs | 16 +++++++++++++++- qos-core/src/coordinator.rs | 3 ++- qos-test/tests/coordinator.rs | 3 +++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/qos-core/src/cli.rs b/qos-core/src/cli.rs index f873f307..76e94d4e 100644 --- a/qos-core/src/cli.rs +++ b/qos-core/src/cli.rs @@ -17,7 +17,6 @@ pub const CID: &str = "cid"; pub const PORT: &str = "port"; /// "usock" pub const USOCK: &str = "usock"; -const MOCK: &str = "mock"; /// Name for the option to specify the quorum key file. pub const QUORUM_FILE_OPT: &str = "quorum-file"; /// Name for the option to specify the pivot key file. @@ -26,7 +25,11 @@ pub const PIVOT_FILE_OPT: &str = "pivot-file"; pub const EPHEMERAL_FILE_OPT: &str = "ephemeral-file"; /// Name for the option to specify the manifest file. pub const MANIFEST_FILE_OPT: &str = "manifest-file"; +/// Name for the option to specify the number of threads for the socket server's +/// thread pool. +pub const THREAD_COUNT: &str = "thread-count"; const APP_USOCK: &str = "app-usock"; +const MOCK: &str = "mock"; /// CLI options for starting up the enclave server. #[derive(Default, Clone, Debug, PartialEq)] @@ -118,6 +121,12 @@ impl EnclaveOpts { .expect("has a default value.") .clone() } + + fn thread_count(&self) -> Option { + self.parsed + .single(THREAD_COUNT) + .map(|n| n.parse().expect("failed to parse `--thread-count`")) + } } /// Enclave server CLI. @@ -143,6 +152,7 @@ impl CLI { opts.nsm(), opts.addr(), opts.app_addr(), + opts.thread_count(), ); } } @@ -198,6 +208,10 @@ impl GetParserForOptions for EnclaveParser { .takes_value(true) .default_value(SEC_APP_SOCK) ) + .token( + Token::new(THREAD_COUNT, "count of threads for the socket servers thread pool") + .takes_value(true) + ) } } diff --git a/qos-core/src/coordinator.rs b/qos-core/src/coordinator.rs index 07252837..2621c357 100644 --- a/qos-core/src/coordinator.rs +++ b/qos-core/src/coordinator.rs @@ -32,11 +32,12 @@ impl Coordinator { nsm: Box, addr: SocketAddress, app_addr: SocketAddress, + thread_count: Option, ) { let handles2 = handles.clone(); std::thread::spawn(move || { let executor = Executor::new(nsm, handles2, app_addr); - SocketServer::listen(addr, executor, None).unwrap(); + SocketServer::listen(addr, executor, thread_count).unwrap(); }); loop { diff --git a/qos-test/tests/coordinator.rs b/qos-test/tests/coordinator.rs index 2bb11735..18d8af89 100644 --- a/qos-test/tests/coordinator.rs +++ b/qos-test/tests/coordinator.rs @@ -41,6 +41,7 @@ fn coordinator_works() { Box::new(MockNsm), SocketAddress::new_unix(usock), SocketAddress::new_unix("./never.sock"), + None, ) }); @@ -94,6 +95,7 @@ fn coordinator_handles_non_zero_exits() { Box::new(MockNsm), SocketAddress::new_unix(usock), SocketAddress::new_unix("./never.sock"), + None, ) }); @@ -147,6 +149,7 @@ fn coordinator_handles_panic() { Box::new(MockNsm), SocketAddress::new_unix(usock), SocketAddress::new_unix("./never.sock"), + None, ) }); From b28247df237b3d04576dbff1659fa43b3cdfb32d Mon Sep 17 00:00:00 2001 From: Zeke Mostov Date: Fri, 5 Aug 2022 12:41:04 -0400 Subject: [PATCH 6/9] Add test for threadpool --- qos-core/src/io/threadpool.rs | 38 +++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/qos-core/src/io/threadpool.rs b/qos-core/src/io/threadpool.rs index 074806db..4de66eec 100644 --- a/qos-core/src/io/threadpool.rs +++ b/qos-core/src/io/threadpool.rs @@ -126,3 +126,41 @@ impl Worker { Worker { thread: Some(thread) } } } + +#[cfg(test)] +mod test { + use std::{ + collections::HashMap, + sync::{Arc, Mutex}, + }; + + use super::ThreadPool; + + #[test] + fn graceful_shutdown_works() { + const KEY: &str = "key"; + const EXECUTIONS: usize = 500; + + let mut db = HashMap::new(); + db.insert(KEY, 0); + + let db = Arc::new(Mutex::new(db)); + + // create job that + let thread_pool = ThreadPool::new(128); + + for _ in 0..EXECUTIONS { + let db2 = db.clone(); + thread_pool + .execute(move || { + *db2.lock().unwrap().get_mut(KEY).unwrap() += 1; + }) + .unwrap(); + } + + // Graceful shutdown + drop(thread_pool); + + assert_eq!(*db.lock().unwrap().get(KEY).unwrap(), EXECUTIONS); + } +} From 3692036154fdcb56a91adda7fba0dbc9ffff9c21 Mon Sep 17 00:00:00 2001 From: Zeke Mostov Date: Fri, 5 Aug 2022 15:53:42 -0400 Subject: [PATCH 7/9] Update qos-core/src/cli.rs --- qos-core/src/cli.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qos-core/src/cli.rs b/qos-core/src/cli.rs index 76e94d4e..059aaffb 100644 --- a/qos-core/src/cli.rs +++ b/qos-core/src/cli.rs @@ -209,7 +209,7 @@ impl GetParserForOptions for EnclaveParser { .default_value(SEC_APP_SOCK) ) .token( - Token::new(THREAD_COUNT, "count of threads for the socket servers thread pool") + Token::new(THREAD_COUNT, "count of threads for the socket server's thread pool") .takes_value(true) ) } From 8b4735ef7f45c96987db751357d234ded1873c35 Mon Sep 17 00:00:00 2001 From: Zeke Mostov Date: Fri, 5 Aug 2022 15:53:47 -0400 Subject: [PATCH 8/9] Update qos-core/src/io/threadpool.rs --- qos-core/src/io/threadpool.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/qos-core/src/io/threadpool.rs b/qos-core/src/io/threadpool.rs index 4de66eec..65d80d8a 100644 --- a/qos-core/src/io/threadpool.rs +++ b/qos-core/src/io/threadpool.rs @@ -146,7 +146,6 @@ mod test { let db = Arc::new(Mutex::new(db)); - // create job that let thread_pool = ThreadPool::new(128); for _ in 0..EXECUTIONS { From 311c39333ba5d1d6f0ed86e341555ae1d35b563b Mon Sep 17 00:00:00 2001 From: Zeke Mostov Date: Fri, 5 Aug 2022 17:52:21 -0400 Subject: [PATCH 9/9] Update doc comments --- qos-core/src/server.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/qos-core/src/server.rs b/qos-core/src/server.rs index ff3e2934..d8d5f96d 100644 --- a/qos-core/src/server.rs +++ b/qos-core/src/server.rs @@ -37,6 +37,13 @@ pub struct SocketServer { impl SocketServer { /// Listen and respond to incoming requests with the given `processor`. + /// + /// # Note Importantly + /// + /// The `processor` must afford the ability to be cloned and passed to a new + /// thread. For every new request, the `processor` will be cloned and passed + /// to a new thread, so if it has any state that state needs to be thread + /// safe after being cloned. #[allow(clippy::needless_pass_by_value)] pub fn listen( addr: SocketAddress, @@ -50,8 +57,8 @@ impl SocketServer { ); let listener = Listener::listen(addr)?; - let thread_pool = ThreadPool::new(thread_count); + for stream in listener { let processor2 = processor.clone(); @@ -63,6 +70,7 @@ impl SocketServer { Ok(()) } + #[allow(clippy::needless_pass_by_value)] fn handle_stream(mut processor: R, stream: io::Stream) { match stream.recv() { Ok(payload) => { @@ -71,7 +79,5 @@ impl SocketServer { } Err(err) => eprintln!("Server::listen error: {:?}", err), } - drop(processor); - drop(stream); } }