diff --git a/framework/src/executor/mod.rs b/framework/src/executor/mod.rs index acf1d70cb..790081056 100644 --- a/framework/src/executor/mod.rs +++ b/framework/src/executor/mod.rs @@ -4,19 +4,22 @@ mod tests; pub use factory::ServiceExecutorFactory; -use std::cell::RefCell; -use std::collections::HashMap; -use std::panic::{self, AssertUnwindSafe}; -use std::rc::Rc; -use std::sync::Arc; +use std::{ + cell::RefCell, + collections::HashMap, + ops::{Deref, DerefMut}, + panic::{self, AssertUnwindSafe}, + rc::Rc, + sync::Arc, +}; use cita_trie::DB as TrieDB; use derive_more::Display; use common_apm::muta_apm; use protocol::traits::{ - Context, Dispatcher, Executor, ExecutorParams, ExecutorResp, NoopDispatcher, ServiceMapping, - ServiceResponse, ServiceState, Storage, + Context, Dispatcher, Executor, ExecutorParams, ExecutorResp, NoopDispatcher, Service, + ServiceMapping, ServiceResponse, ServiceState, Storage, }; use protocol::types::{ Address, Hash, MerkleRoot, Receipt, ReceiptResponse, ServiceContext, ServiceContextParams, @@ -27,21 +30,118 @@ use protocol::{ProtocolError, ProtocolErrorKind, ProtocolResult}; use crate::binding::sdk::{DefaultChainQuerier, DefaultServiceSDK}; use crate::binding::state::{GeneralServiceState, MPTTrie}; +trait TxHooks { + fn before(&mut self, _: Context, _: ServiceContext) -> ProtocolResult<()> { + Ok(()) + } + + fn after(&mut self, _: Context, _: ServiceContext) -> ProtocolResult<()> { + Ok(()) + } +} + +impl TxHooks for () {} + enum HookType { Before, After, } -#[derive(Clone)] +#[derive(Clone, Copy)] enum ExecType { Read, Write, } +struct ServiceStateMap(HashMap>>>); + +impl ServiceStateMap { + fn new() -> ServiceStateMap { + Self(HashMap::new()) + } +} + +impl Deref for ServiceStateMap { + type Target = HashMap>>>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for ServiceStateMap { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl ServiceStateMap { + fn stash(&self) -> ProtocolResult<()> { + for state in self.0.values() { + state.borrow_mut().stash()?; + } + + Ok(()) + } + + fn revert_cache(&self) -> ProtocolResult<()> { + for state in self.0.values() { + state.borrow_mut().revert_cache()?; + } + + Ok(()) + } +} + +struct CommitHooks { + inner: Vec>, + states: Rc>, +} + +impl CommitHooks { + fn new(hooks: Vec>, states: Rc>) -> CommitHooks { + Self { + inner: hooks, + states, + } + } + + // bagua kan 101 :) + fn kan R, R>(states: Rc>, hook: H) -> ProtocolResult<()> { + match panic::catch_unwind(AssertUnwindSafe(hook)) { + Ok(_) => states.stash(), + Err(_) => states.revert_cache(), + } + } +} + +impl TxHooks for CommitHooks { + // TODO: support abort execution + fn before(&mut self, _context: Context, service_context: ServiceContext) -> ProtocolResult<()> { + for hook in self.inner.iter_mut() { + Self::kan(Rc::clone(&self.states), || { + hook.tx_hook_before_(service_context.clone()) + })?; + } + + Ok(()) + } + + fn after(&mut self, _context: Context, service_context: ServiceContext) -> ProtocolResult<()> { + for hook in self.inner.iter_mut() { + Self::kan(Rc::clone(&self.states), || { + hook.tx_hook_after_(service_context.clone()) + })?; + } + + Ok(()) + } +} + pub struct ServiceExecutor { service_mapping: Arc, querier: Rc>, - states: Rc>>>>, + states: Rc>, root_state: Rc>>, } @@ -67,7 +167,7 @@ impl ProtocolResult { let querier = Rc::new(DefaultChainQuerier::new(Arc::clone(&storage))); - let mut states = HashMap::new(); + let mut states = ServiceStateMap::new(); for name in mapping.list_service_name().into_iter() { let trie = MPTTrie::new(Arc::clone(&trie_db)); @@ -111,7 +211,7 @@ impl MPTTrie::from(service_root, Arc::clone(&trie_db))?, @@ -141,19 +241,11 @@ impl ProtocolResult<()> { - for state in self.states.values() { - state.borrow_mut().stash()?; - } - - Ok(()) + self.states.stash() } fn revert_cache(&mut self) -> ProtocolResult<()> { - for state in self.states.values() { - state.borrow_mut().revert_cache()?; - } - - Ok(()) + self.states.revert_cache() } #[muta_apm::derive::tracing_span( @@ -236,20 +328,44 @@ impl Box { + match exec_type { + ExecType::Read => Box::new(()), + ExecType::Write => { + let mut tx_hooks = vec![]; + + for name in self.service_mapping.list_service_name().into_iter() { + let sdk = self + .get_sdk(&name) + .unwrap_or_else(|e| panic!("get target service sdk failed: {}", e)); + + let tx_hook_service = self + .service_mapping + .get_service(name.as_str(), sdk) + .unwrap_or_else(|e| panic!("get target service sdk failed: {}", e)); + + tx_hooks.push(tx_hook_service); + } + + let hooks = CommitHooks::new(tx_hooks, Rc::clone(&self.states)); + Box::new(hooks) + } + } + } + fn catch_call( &mut self, - context: ServiceContext, + context: Context, + service_context: ServiceContext, exec_type: ExecType, ) -> ProtocolResult> { - let result = match exec_type { - ExecType::Read => panic::catch_unwind(AssertUnwindSafe(|| { - self.call(context.clone(), exec_type.clone()) - })), - ExecType::Write => panic::catch_unwind(AssertUnwindSafe(|| { - self.call_with_tx_hooks(context.clone(), exec_type.clone()) - })), - }; - match result { + let mut tx_hooks = self.get_tx_hooks(exec_type); + + tx_hooks.before(context.clone(), service_context.clone())?; + + let ret = match panic::catch_unwind(AssertUnwindSafe(|| { + self.call(service_context.clone(), exec_type) + })) { Ok(r) => { self.stash()?; Ok(r) @@ -259,38 +375,11 @@ impl ServiceResponse { - let mut tx_hook_services = vec![]; - for name in self.service_mapping.list_service_name().into_iter() { - let sdk = self - .get_sdk(&name) - .unwrap_or_else(|e| panic!("get target service sdk failed: {}", e)); - let tx_hook_service = self - .service_mapping - .get_service(name.as_str(), sdk) - .unwrap_or_else(|e| panic!("get target service sdk failed: {}", e)); - tx_hook_services.push(tx_hook_service); - } - // TODO: If tx_hook_before_ failed, we should not exec the tx. - // Need a mechanism for this. - for tx_hook_service in tx_hook_services.iter_mut() { - tx_hook_service.tx_hook_before_(context.clone()); - } - let original_res = self.call(context.clone(), exec_type); - // TODO: If the tx fails, status tx_hook_after_ changes will also be reverted. - // It may not be what the developer want. - // Need a new mechanism for this. - for tx_hook_service in tx_hook_services.iter_mut() { - tx_hook_service.tx_hook_after_(context.clone()); - } - original_res + tx_hooks.after(context, service_context)?; + + ret } fn call(&self, context: ServiceContext, exec_type: ExecType) -> ServiceResponse { @@ -325,7 +414,7 @@ impl { sdk: SDK, } -#[derive(Deserialize, Serialize, Clone, Debug)] -pub struct TestReadPayload { - pub key: String, -} - -#[derive(Deserialize, Serialize, Clone, Debug, Default)] -pub struct TestReadResponse { - pub value: String, -} - #[derive(Deserialize, Serialize, Clone, Debug)] pub struct TestWritePayload { pub key: String, @@ -36,14 +26,9 @@ impl TestService { #[cycles(100_00)] #[read] - fn test_read( - &self, - ctx: ServiceContext, - payload: TestReadPayload, - ) -> ServiceResponse { - let value: String = self.sdk.get_value(&payload.key).unwrap_or_default(); - let res = TestReadResponse { value }; - ServiceResponse::::from_succeed(res) + fn test_read(&self, ctx: ServiceContext, payload: String) -> ServiceResponse { + let value: String = self.sdk.get_value(&payload).unwrap_or_default(); + ServiceResponse::from_succeed(value) } #[cycles(210_00)] @@ -91,6 +76,40 @@ impl TestService { ServiceResponse::::from_succeed(TestWriteResponse {}) } + #[cycles(210_00)] + #[write] + fn test_panic(&mut self, ctx: ServiceContext, _payload: String) -> ServiceResponse<()> { + panic!("hello panic"); + } + + #[cycles(210_00)] + #[write] + fn tx_hook_before_panic( + &mut self, + ctx: ServiceContext, + _payload: String, + ) -> ServiceResponse<()> { + self.sdk.set_value( + "tx_hook_before_panic".to_owned(), + "tx_hook_before_panic".to_owned(), + ); + ServiceResponse::from_succeed(()) + } + + #[cycles(210_00)] + #[write] + fn tx_hook_after_panic( + &mut self, + ctx: ServiceContext, + _payload: String, + ) -> ServiceResponse<()> { + self.sdk.set_value( + "tx_hook_after_panic".to_owned(), + "tx_hook_after_panic".to_owned(), + ); + ServiceResponse::from_succeed(()) + } + #[tx_hook_before] fn test_tx_hook_before(&mut self, ctx: ServiceContext) { if ctx.get_service_name() == "test" @@ -98,6 +117,12 @@ impl TestService { { ctx.emit_event("test_tx_hook_before invoked".to_owned()); } + + if ctx.get_service_method() == "tx_hook_before_panic" { + panic!("tx hook before"); + } + + self.sdk.set_value("before".to_owned(), "before".to_owned()); } #[tx_hook_after] @@ -107,5 +132,11 @@ impl TestService { { ctx.emit_event("test_tx_hook_after invoked".to_owned()); } + + if ctx.get_service_method() == "tx_hook_after_panic" { + panic!("tx hook before"); + } + + self.sdk.set_value("after".to_owned(), "after".to_owned()); } }