Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add MPCContext derive + test utils refactor #497

Merged
merged 4 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions macros/context-derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ gadget-sdk = { path = "../../sdk", features = ["std"] }
alloy-network = { workspace = true }
alloy-provider = { workspace = true }
alloy-transport = { workspace = true }
round-based = { workspace = true }
serde = { workspace = true }

[features]
default = ["std"]
Expand Down
15 changes: 15 additions & 0 deletions macros/context-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ mod eigenlayer;
mod evm;
/// Keystore context extension implementation.
mod keystore;
/// MPC context extension implementation.
mod mpc;
/// Services context extension implementation.
mod services;
/// Tangle Subxt Client context extension implementation.
Expand Down Expand Up @@ -88,3 +90,16 @@ pub fn derive_eigenlayer_context(input: TokenStream) -> TokenStream {
Err(err) => TokenStream::from(err.to_compile_error()),
}
}

/// Derive macro for generating Context Extensions trait implementation for `MPCContext`.
#[proc_macro_derive(MPCContext, attributes(config))]
pub fn derive_mpc_context(input: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(input as syn::DeriveInput);
let result = cfg::find_config_field(&input.ident, &input.data)
.map(|config_field| mpc::generate_context_impl(input, config_field));

match result {
Ok(expanded) => TokenStream::from(expanded),
Err(err) => TokenStream::from(err.to_compile_error()),
}
}
158 changes: 158 additions & 0 deletions macros/context-derive/src/mpc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
use quote::quote;
use syn::DeriveInput;

use crate::cfg::FieldInfo;

/// Generate the `MPCContext` implementation for the given struct.
#[allow(clippy::too_many_lines)]
pub fn generate_context_impl(
DeriveInput {
ident: name,
generics,
..
}: DeriveInput,
config_field: FieldInfo,
) -> proc_macro2::TokenStream {
let _field_access = match config_field {
FieldInfo::Named(ident) => quote! { self.#ident },
FieldInfo::Unnamed(index) => quote! { self.#index },
};

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

quote! {
#[gadget_sdk::async_trait::async_trait]
impl #impl_generics gadget_sdk::ctx::MPCContext for #name #ty_generics #where_clause {
/// Returns a reference to the configuration
#[inline]
fn config(&self) -> &gadget_sdk::config::StdGadgetConfiguration {
&self.config
}

/// Returns the network protocol identifier for this context
#[inline]
fn network_protocol(&self) -> String {
let name = stringify!(#name).to_string();
format!("/{}/1.0.0", name.to_lowercase())
}

fn create_network_delivery_wrapper<M>(
&self,
mux: std::sync::Arc<gadget_sdk::network::NetworkMultiplexer>,
party_index: gadget_sdk::round_based::PartyIndex,
task_hash: [u8; 32],
parties: std::collections::BTreeMap<gadget_sdk::round_based::PartyIndex, gadget_sdk::subxt_core::ext::sp_core::ecdsa::Public>,
) -> Result<gadget_sdk::network::round_based_compat::NetworkDeliveryWrapper<M>, gadget_sdk::Error>
where
M: Clone + Send + Unpin + 'static + gadget_sdk::serde::Serialize + gadget_sdk::serde::de::DeserializeOwned + gadget_sdk::round_based::ProtocolMessage,
{
Ok(gadget_sdk::network::round_based_compat::NetworkDeliveryWrapper::new(mux, party_index, task_hash, parties))
}

async fn get_party_index(
&self,
) -> Result<gadget_sdk::round_based::PartyIndex, gadget_sdk::Error> {
Ok(self.get_party_index_and_operators().await?.0 as _)
}

async fn get_participants(
&self,
client: &gadget_sdk::ext::subxt::OnlineClient<gadget_sdk::clients::tangle::runtime::TangleConfig>,
) -> Result<
std::collections::BTreeMap<gadget_sdk::round_based::PartyIndex, gadget_sdk::subxt::utils::AccountId32>,
gadget_sdk::Error,
> {
Ok(self.get_party_index_and_operators().await?.1.into_iter().enumerate().map(|(i, (id, _))| (i as _, id)).collect())
}

/// Retrieves the current blueprint ID from the configuration
///
/// # Errors
/// Returns an error if the blueprint ID is not found in the configuration
fn blueprint_id(&self) -> gadget_sdk::color_eyre::Result<u64> {
self.config()
.protocol_specific
.tangle()
.map(|c| c.blueprint_id)
.map_err(|err| gadget_sdk::color_eyre::Report::msg("Blueprint ID not found in configuration: {err}"))
}

/// Retrieves the current party index and operator mapping
///
/// # Errors
/// Returns an error if:
/// - Failed to retrieve operator keys
/// - Current party is not found in the operator list
async fn get_party_index_and_operators(
&self,
) -> gadget_sdk::color_eyre::Result<(usize, std::collections::BTreeMap<gadget_sdk::subxt::utils::AccountId32, gadget_sdk::subxt_core::ext::sp_core::ecdsa::Public>)> {
let parties = self.current_service_operators_ecdsa_keys().await?;
let my_id = self.config.first_sr25519_signer()?.account_id();

gadget_sdk::trace!(
"Looking for {my_id:?} in parties: {:?}",
parties.keys().collect::<Vec<_>>()
);

let index_of_my_id = parties
.iter()
.position(|(id, _)| id == &my_id)
.ok_or_else(|| gadget_sdk::color_eyre::Report::msg("Party not found in operator list"))?;

Ok((index_of_my_id, parties))
}

/// Retrieves the ECDSA keys for all current service operators
///
/// # Errors
/// Returns an error if:
/// - Failed to connect to the Tangle client
/// - Failed to retrieve operator information
/// - Missing ECDSA key for any operator
async fn current_service_operators_ecdsa_keys(
&self,
) -> gadget_sdk::color_eyre::Result<std::collections::BTreeMap<gadget_sdk::subxt::utils::AccountId32, gadget_sdk::subxt_core::ext::sp_core::ecdsa::Public>> {
let client = self.tangle_client().await?;
let current_blueprint = self.blueprint_id()?;
let current_service_op = self.current_service_operators(&client).await?;
let storage = client.storage().at_latest().await?;

let mut map = std::collections::BTreeMap::new();
for (operator, _) in current_service_op {
let addr = gadget_sdk::ext::tangle_subxt::tangle_testnet_runtime::api::storage()
.services()
.operators(current_blueprint, &operator);

let maybe_pref = storage.fetch(&addr).await.map_err(|err| {
gadget_sdk::color_eyre::Report::msg("Failed to fetch operator storage for {operator}: {err}")
})?;

if let Some(pref) = maybe_pref {
map.insert(operator, gadget_sdk::subxt_core::ext::sp_core::ecdsa::Public(pref.key));
} else {
return Err(gadget_sdk::color_eyre::Report::msg("Missing ECDSA key for operator {operator}"));
}
}

Ok(map)
}

/// Retrieves the current call ID for this job
///
/// # Errors
/// Returns an error if failed to retrieve the call ID from storage
async fn current_call_id(&self) -> gadget_sdk::color_eyre::Result<u64> {
let client = self.tangle_client().await?;
let addr = gadget_sdk::ext::tangle_subxt::tangle_testnet_runtime::api::storage().services().next_job_call_id();
let storage = client.storage().at_latest().await?;

let maybe_call_id = storage
.fetch_or_default(&addr)
.await
.map_err(|err| gadget_sdk::color_eyre::Report::msg("Failed to fetch current call ID: {err}"))?;

Ok(maybe_call_id.saturating_sub(1))
}
}
}
}
14 changes: 8 additions & 6 deletions macros/context-derive/tests/tests.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
mod ui;

#[cfg(test)]
mod tests {
#[test]
fn test_derive_context() {
let t = trybuild::TestCases::new();
t.pass("tests/ui/01_basic.rs");
t.pass("tests/ui/02_unnamed_fields.rs");
t.pass("tests/ui/03_generic_struct.rs");
t.compile_fail("tests/ui/04_missing_config_attr.rs");
t.compile_fail("tests/ui/05_not_a_struct.rs");
t.compile_fail("tests/ui/06_unit_struct.rs");
t.pass("tests/ui/basic.rs");
t.pass("tests/ui/unnamed_fields.rs");
t.pass("tests/ui/generic_struct.rs");
t.compile_fail("tests/ui/missing_config_attr.rs");
t.compile_fail("tests/ui/not_a_struct.rs");
t.compile_fail("tests/ui/unit_struct.rs");
}
}
24 changes: 0 additions & 24 deletions macros/context-derive/tests/ui/01_basic.rs

This file was deleted.

98 changes: 98 additions & 0 deletions macros/context-derive/tests/ui/basic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use gadget_sdk::async_trait::async_trait;
use gadget_sdk::config::{GadgetConfiguration, StdGadgetConfiguration};
use gadget_sdk::ctx::{
EVMProviderContext, KeystoreContext, MPCContext, ServicesContext, TangleClientContext,
};
use gadget_sdk::network::{Network, NetworkMultiplexer, ProtocolMessage};
use gadget_sdk::store::LocalDatabase;
use gadget_sdk::subxt_core::ext::sp_core::ecdsa::Public;
use gadget_sdk::subxt_core::tx::signer::Signer;
use gadget_sdk::Error;
use round_based::ProtocolMessage as RoundBasedProtocolMessage;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::sync::Arc;

#[derive(KeystoreContext, EVMProviderContext, TangleClientContext, ServicesContext, MPCContext)]
#[allow(dead_code)]
struct MyContext {
foo: String,
#[config]
config: StdGadgetConfiguration,
store: Arc<LocalDatabase<u64>>,
}

#[allow(dead_code)]
fn main() {
let body = async {
let ctx = MyContext {
foo: "bar".to_string(),
config: GadgetConfiguration::default(),
store: Arc::new(LocalDatabase::open("test.json")),
};

// Test existing context functions
let _keystore = ctx.keystore();
let _evm_provider = ctx.evm_provider().await;
let tangle_client = ctx.tangle_client().await.unwrap();
let _services = ctx.current_service_operators(&tangle_client).await.unwrap();

// Test MPC context utility functions
let _config = ctx.config();
let _protocol = ctx.network_protocol();

// Test MPC context functions

let mux = Arc::new(NetworkMultiplexer::new(StubNetwork));
let party_index = 0;
let task_hash = [0u8; 32];
let mut parties = BTreeMap::<u16, _>::new();
parties.insert(0, Public([0u8; 33]));

// Test network delivery wrapper creation
let _network_wrapper = ctx.create_network_delivery_wrapper::<StubMessage>(
mux.clone(),
party_index,
task_hash,
parties.clone(),
);

// Test party index retrieval
let _party_idx = ctx.get_party_index().await;

// Test participants retrieval
let _participants = ctx.get_participants(&tangle_client).await;

// Test blueprint ID retrieval
let _blueprint_id = ctx.blueprint_id();

// Test party index and operators retrieval
let _party_idx_ops = ctx.get_party_index_and_operators().await;

// Test service operators ECDSA keys retrieval
let _operator_keys = ctx.current_service_operators_ecdsa_keys().await;

// Test current call ID retrieval
let _call_id = ctx.current_call_id().await;
};

drop(body);
}

#[derive(RoundBasedProtocolMessage, Clone, Serialize, Deserialize)]
enum StubMessage {}

#[allow(dead_code)]
struct StubNetwork;

#[async_trait]
impl Network for StubNetwork {
async fn next_message(&self) -> Option<ProtocolMessage> {
None
}

async fn send_message(&self, message: ProtocolMessage) -> Result<(), Error> {
drop(message);
Ok(())
}
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
use gadget_sdk::config::StdGadgetConfiguration;
use gadget_sdk::config::{GadgetConfiguration, StdGadgetConfiguration};
use gadget_sdk::ctx::{EVMProviderContext, KeystoreContext, ServicesContext, TangleClientContext};

#[derive(KeystoreContext, EVMProviderContext, TangleClientContext, ServicesContext)]
#[allow(dead_code)]
struct MyContext<T, U> {
foo: T,
bar: U,
#[config]
sdk_config: StdGadgetConfiguration,
}

#[allow(dead_code)]
fn main() {
let body = async {
let ctx = MyContext {
foo: "bar".to_string(),
bar: 42,
sdk_config: Default::default(),
sdk_config: GadgetConfiguration::default(),
};
let _keystore = ctx.keystore();
let _evm_provider = ctx.evm_provider().await.unwrap();
let tangle_client = ctx.tangle_client().await.unwrap();
let _services = ctx.current_service_operators(&tangle_client).await.unwrap();
};

let _ = body;
drop(body);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
error: No field with #[config] attribute found, please add #[config] to the field that holds the `gadget_sdk::config::GadgetConfiguration`
--> tests/ui/04_missing_config_attr.rs:5:8
--> tests/ui/missing_config_attr.rs:5:8
|
5 | struct MyContext {
| ^^^^^^^^^
3 changes: 3 additions & 0 deletions macros/context-derive/tests/ui/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mod basic;
mod generic_struct;
mod unnamed_fields;
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
error: Context Extensions traits can only be derived for structs
--> tests/ui/05_not_a_struct.rs:4:6
--> tests/ui/not_a_struct.rs:4:6
|
4 | enum MyContext {
| ^^^^^^^^^
Loading
Loading