Skip to content
This repository has been archived by the owner on Sep 13, 2022. It is now read-only.

feat(executor): indenpendent tx hook states commit #316

Merged
merged 2 commits into from
Jun 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 159 additions & 69 deletions framework/src/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,22 @@ mod tests;

pub use factory::ServiceExecutorFactory;

use std::cell::RefCell;
use std::collections::HashMap;
use std::panic::{self, AssertUnwindSafe};
use std::rc::Rc;
use std::sync::Arc;
use std::{
cell::RefCell,
collections::HashMap,
ops::{Deref, DerefMut},
panic::{self, AssertUnwindSafe},
rc::Rc,
sync::Arc,
};

use cita_trie::DB as TrieDB;
use derive_more::Display;

use common_apm::muta_apm;
use protocol::traits::{
Context, Dispatcher, Executor, ExecutorParams, ExecutorResp, NoopDispatcher, ServiceMapping,
ServiceResponse, ServiceState, Storage,
Context, Dispatcher, Executor, ExecutorParams, ExecutorResp, NoopDispatcher, Service,
ServiceMapping, ServiceResponse, ServiceState, Storage,
};
use protocol::types::{
Address, Hash, MerkleRoot, Receipt, ReceiptResponse, ServiceContext, ServiceContextParams,
Expand All @@ -27,21 +30,118 @@ use protocol::{ProtocolError, ProtocolErrorKind, ProtocolResult};
use crate::binding::sdk::{DefaultChainQuerier, DefaultServiceSDK};
use crate::binding::state::{GeneralServiceState, MPTTrie};

trait TxHooks {
fn before(&mut self, _: Context, _: ServiceContext) -> ProtocolResult<()> {
Ok(())
}

fn after(&mut self, _: Context, _: ServiceContext) -> ProtocolResult<()> {
Ok(())
}
}

impl TxHooks for () {}

enum HookType {
Before,
After,
}

#[derive(Clone)]
#[derive(Clone, Copy)]
enum ExecType {
Read,
Write,
}

struct ServiceStateMap<DB: TrieDB>(HashMap<String, Rc<RefCell<GeneralServiceState<DB>>>>);

impl<DB: TrieDB> ServiceStateMap<DB> {
fn new() -> ServiceStateMap<DB> {
Self(HashMap::new())
}
}

impl<DB: TrieDB> Deref for ServiceStateMap<DB> {
type Target = HashMap<String, Rc<RefCell<GeneralServiceState<DB>>>>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<DB: TrieDB> DerefMut for ServiceStateMap<DB> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl<DB: TrieDB> ServiceStateMap<DB> {
fn stash(&self) -> ProtocolResult<()> {
for state in self.0.values() {
state.borrow_mut().stash()?;
}

Ok(())
}

fn revert_cache(&self) -> ProtocolResult<()> {
for state in self.0.values() {
state.borrow_mut().revert_cache()?;
}

Ok(())
}
}

struct CommitHooks<DB: TrieDB> {
inner: Vec<Box<dyn Service>>,
states: Rc<ServiceStateMap<DB>>,
}

impl<DB: TrieDB> CommitHooks<DB> {
fn new(hooks: Vec<Box<dyn Service>>, states: Rc<ServiceStateMap<DB>>) -> CommitHooks<DB> {
Self {
inner: hooks,
states,
}
}

// bagua kan 101 :)
fn kan<H: FnOnce() -> R, R>(states: Rc<ServiceStateMap<DB>>, hook: H) -> ProtocolResult<()> {
match panic::catch_unwind(AssertUnwindSafe(hook)) {
Ok(_) => states.stash(),
Err(_) => states.revert_cache(),
}
}
}

impl<DB: TrieDB> TxHooks for CommitHooks<DB> {
// TODO: support abort execution
fn before(&mut self, _context: Context, service_context: ServiceContext) -> ProtocolResult<()> {
for hook in self.inner.iter_mut() {
Self::kan(Rc::clone(&self.states), || {
hook.tx_hook_before_(service_context.clone())
})?;
}

Ok(())
}

fn after(&mut self, _context: Context, service_context: ServiceContext) -> ProtocolResult<()> {
for hook in self.inner.iter_mut() {
Self::kan(Rc::clone(&self.states), || {
hook.tx_hook_after_(service_context.clone())
})?;
}

Ok(())
}
}

pub struct ServiceExecutor<S: Storage, DB: TrieDB, Mapping: ServiceMapping> {
service_mapping: Arc<Mapping>,
querier: Rc<DefaultChainQuerier<S>>,
states: Rc<HashMap<String, Rc<RefCell<GeneralServiceState<DB>>>>>,
states: Rc<ServiceStateMap<DB>>,
root_state: Rc<RefCell<GeneralServiceState<DB>>>,
}

Expand All @@ -67,7 +167,7 @@ impl<S: 'static + Storage, DB: 'static + TrieDB, Mapping: 'static + ServiceMappi
) -> ProtocolResult<MerkleRoot> {
let querier = Rc::new(DefaultChainQuerier::new(Arc::clone(&storage)));

let mut states = HashMap::new();
let mut states = ServiceStateMap::new();
for name in mapping.list_service_name().into_iter() {
let trie = MPTTrie::new(Arc::clone(&trie_db));

Expand Down Expand Up @@ -111,7 +211,7 @@ impl<S: 'static + Storage, DB: 'static + TrieDB, Mapping: 'static + ServiceMappi
let trie = MPTTrie::from(root, Arc::clone(&trie_db))?;
let root_state = GeneralServiceState::new(trie);

let mut states = HashMap::new();
let mut states = ServiceStateMap::new();
for name in service_mapping.list_service_name().into_iter() {
let trie = match root_state.get(&name)? {
Some(service_root) => MPTTrie::from(service_root, Arc::clone(&trie_db))?,
Expand Down Expand Up @@ -141,19 +241,11 @@ impl<S: 'static + Storage, DB: 'static + TrieDB, Mapping: 'static + ServiceMappi
}

fn stash(&mut self) -> ProtocolResult<()> {
for state in self.states.values() {
state.borrow_mut().stash()?;
}

Ok(())
self.states.stash()
}

fn revert_cache(&mut self) -> ProtocolResult<()> {
for state in self.states.values() {
state.borrow_mut().revert_cache()?;
}

Ok(())
self.states.revert_cache()
}

#[muta_apm::derive::tracing_span(
Expand Down Expand Up @@ -236,20 +328,44 @@ impl<S: 'static + Storage, DB: 'static + TrieDB, Mapping: 'static + ServiceMappi
Ok(ServiceContext::new(ctx_params))
}

fn get_tx_hooks(&self, exec_type: ExecType) -> Box<dyn TxHooks> {
match exec_type {
ExecType::Read => Box::new(()),
ExecType::Write => {
let mut tx_hooks = vec![];

for name in self.service_mapping.list_service_name().into_iter() {
let sdk = self
.get_sdk(&name)
.unwrap_or_else(|e| panic!("get target service sdk failed: {}", e));

let tx_hook_service = self
.service_mapping
.get_service(name.as_str(), sdk)
.unwrap_or_else(|e| panic!("get target service sdk failed: {}", e));

tx_hooks.push(tx_hook_service);
}

let hooks = CommitHooks::new(tx_hooks, Rc::clone(&self.states));
Box::new(hooks)
}
}
}

fn catch_call(
&mut self,
context: ServiceContext,
context: Context,
service_context: ServiceContext,
exec_type: ExecType,
) -> ProtocolResult<ServiceResponse<String>> {
let result = match exec_type {
ExecType::Read => panic::catch_unwind(AssertUnwindSafe(|| {
self.call(context.clone(), exec_type.clone())
})),
ExecType::Write => panic::catch_unwind(AssertUnwindSafe(|| {
self.call_with_tx_hooks(context.clone(), exec_type.clone())
})),
};
match result {
let mut tx_hooks = self.get_tx_hooks(exec_type);

tx_hooks.before(context.clone(), service_context.clone())?;

let ret = match panic::catch_unwind(AssertUnwindSafe(|| {
self.call(service_context.clone(), exec_type)
})) {
Ok(r) => {
self.stash()?;
Ok(r)
Expand All @@ -259,38 +375,11 @@ impl<S: 'static + Storage, DB: 'static + TrieDB, Mapping: 'static + ServiceMappi
log::error!("inner chain error occurred when calling service: {:?}", e);
Err(ExecutorError::CallService(format!("{:?}", e)).into())
}
}
}
};

fn call_with_tx_hooks(
&self,
context: ServiceContext,
exec_type: ExecType,
) -> ServiceResponse<String> {
let mut tx_hook_services = vec![];
for name in self.service_mapping.list_service_name().into_iter() {
let sdk = self
.get_sdk(&name)
.unwrap_or_else(|e| panic!("get target service sdk failed: {}", e));
let tx_hook_service = self
.service_mapping
.get_service(name.as_str(), sdk)
.unwrap_or_else(|e| panic!("get target service sdk failed: {}", e));
tx_hook_services.push(tx_hook_service);
}
// TODO: If tx_hook_before_ failed, we should not exec the tx.
// Need a mechanism for this.
for tx_hook_service in tx_hook_services.iter_mut() {
tx_hook_service.tx_hook_before_(context.clone());
}
let original_res = self.call(context.clone(), exec_type);
// TODO: If the tx fails, status tx_hook_after_ changes will also be reverted.
// It may not be what the developer want.
// Need a new mechanism for this.
for tx_hook_service in tx_hook_services.iter_mut() {
tx_hook_service.tx_hook_after_(context.clone());
}
original_res
tx_hooks.after(context, service_context)?;

ret
}

fn call(&self, context: ServiceContext, exec_type: ExecType) -> ServiceResponse<String> {
Expand Down Expand Up @@ -325,7 +414,7 @@ impl<S: 'static + Storage, DB: 'static + TrieDB, Mapping: 'static + ServiceMappi
let mut receipts = txs
.iter()
.map(|stx| {
let context = self.get_context(
let service_context = self.get_context(
Some(stx.tx_hash.clone()),
Some(stx.raw.nonce.clone()),
&stx.raw.sender,
Expand All @@ -335,22 +424,23 @@ impl<S: 'static + Storage, DB: 'static + TrieDB, Mapping: 'static + ServiceMappi
&stx.raw.request,
)?;

let exec_resp = self.catch_call(context.clone(), ExecType::Write)?;
let exec_resp =
self.catch_call(ctx.clone(), service_context.clone(), ExecType::Write)?;
let events = if exec_resp.is_error() {
Vec::new()
} else {
context.get_events()
service_context.get_events()
};

Ok(Receipt {
state_root: MerkleRoot::from_empty(),
height: context.get_current_height(),
height: service_context.get_current_height(),
tx_hash: stx.tx_hash.clone(),
cycles_used: context.get_cycles_used(),
cycles_used: service_context.get_cycles_used(),
events,
response: ReceiptResponse {
service_name: context.get_service_name().to_owned(),
method: context.get_service_method().to_owned(),
service_name: service_context.get_service_name().to_owned(),
method: service_context.get_service_method().to_owned(),
response: exec_resp,
},
})
Expand Down
Loading