Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set prot cfg fields in test #13178

Merged
merged 6 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 51 additions & 4 deletions crates/sui-protocol-config-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
/// /// Returns a map of all features to values
/// pub fn feature_map(&self) -> std::collections::BTreeMap<String, bool>;
/// ```
#[proc_macro_derive(ProtocolConfigGetters)]
pub fn getters_macro(input: TokenStream) -> TokenStream {
#[proc_macro_derive(ProtocolConfigAccessors)]
pub fn accessors_macro(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);

let struct_name = &ast.ident;
Expand Down Expand Up @@ -82,6 +82,12 @@ pub fn getters_macro(input: TokenStream) -> TokenStream {
let as_option_name = format!("{field_name}_as_option");
let as_option_name: proc_macro2::TokenStream =
as_option_name.parse().unwrap();
let test_setter_name: proc_macro2::TokenStream =
format!("set_{field_name}_for_testing").parse().unwrap();
let test_un_setter_name: proc_macro2::TokenStream =
format!("disable_{field_name}_for_testing").parse().unwrap();
let test_setter_from_str_name: proc_macro2::TokenStream =
format!("set_{field_name}_from_str_for_testing").parse().unwrap();

let getter = quote! {
// Derive the getter
Expand All @@ -94,6 +100,29 @@ pub fn getters_macro(input: TokenStream) -> TokenStream {
}
};

let test_setter = quote! {
// Derive the setter
pub fn #test_setter_name(&mut self, val: #inner_type) {
self.#field_name = Some(val);
}

// Derive the setter from String
pub fn #test_setter_from_str_name(&mut self, val: String) {
use std::str::FromStr;
self.#test_setter_name(#inner_type::from_str(&val).unwrap());
}

// Derive the un-setter
pub fn #test_un_setter_name(&mut self) {
self.#field_name = None;
}
};

let value_setter = quote! {
stringify!(#field_name) => self.#test_setter_from_str_name(val),
};


let value_lookup = quote! {
stringify!(#field_name) => self.#field_name.map(|v| ProtocolConfigValue::#inner_type(v)),
};
Expand All @@ -112,7 +141,7 @@ pub fn getters_macro(input: TokenStream) -> TokenStream {
})
};

Some((getter, (value_lookup, field_name_str)))
Some(((getter, (test_setter, value_setter)), (value_lookup, field_name_str)))
}
_ => None,
}
Expand All @@ -121,7 +150,12 @@ pub fn getters_macro(input: TokenStream) -> TokenStream {
},
_ => panic!("Only structs supported."),
};
let (getters, (value_lookup, field_names_str)): (Vec<_>, (Vec<_>, Vec<_>)) = tokens.unzip();

#[allow(clippy::type_complexity)]
let ((getters, (test_setters, value_setters)), (value_lookup, field_names_str)): (
(Vec<_>, (Vec<_>, Vec<_>)),
(Vec<_>, Vec<_>),
) = tokens.unzip();
let inner_types = Vec::from_iter(seen_types);
let output = quote! {
// For each getter, expand it out into a function in the impl block
Expand Down Expand Up @@ -154,6 +188,19 @@ pub fn getters_macro(input: TokenStream) -> TokenStream {
}
}

// For each attr, derive a setter from the raw value and from string repr
#[cfg(debug_assertions)]
impl #struct_name {
#(#test_setters)*

pub fn set_attr_for_testing(&mut self, attr: String, val: String) {
match attr.as_str() {
#(#value_setters)*
_ => panic!("Attempting to set unknown attribute: {}", attr),
}
}
}

#[allow(non_camel_case_types)]
#[derive(Clone, Serialize, Debug, PartialEq, Deserialize, schemars::JsonSchema)]
pub enum ProtocolConfigValue {
Expand Down
33 changes: 19 additions & 14 deletions crates/sui-protocol-config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use std::cell::RefCell;
use std::sync::atomic::{AtomicBool, Ordering};
use sui_protocol_config_macros::{ProtocolConfigFeatureFlagsGetters, ProtocolConfigGetters};
use sui_protocol_config_macros::{ProtocolConfigAccessors, ProtocolConfigFeatureFlagsGetters};
use tracing::{info, warn};

/// The minimum and maximum protocol versions supported by this build.
Expand Down Expand Up @@ -303,7 +303,7 @@ impl ConsensusTransactionOrdering {
/// return `None` if the field is not defined at that version.
/// - If you want a customized getter, you can add a method in the impl.
#[skip_serializing_none]
#[derive(Clone, Serialize, Debug, ProtocolConfigGetters)]
#[derive(Clone, Serialize, Debug, ProtocolConfigAccessors)]
pub struct ProtocolConfig {
pub version: ProtocolVersion,

Expand Down Expand Up @@ -1384,12 +1384,6 @@ impl ProtocolConfig {

// Setters for tests
impl ProtocolConfig {
pub fn set_max_function_definitions_for_testing(&mut self, m: u64) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed these as they're now auto-derived

self.max_function_definitions = Some(m)
}
pub fn set_buffer_stake_for_protocol_upgrade_bps_for_testing(&mut self, b: u64) {
self.buffer_stake_for_protocol_upgrade_bps = Some(b)
}
pub fn set_package_upgrades_for_testing(&mut self, val: bool) {
self.feature_flags.package_upgrades = val
}
Expand All @@ -1403,12 +1397,6 @@ impl ProtocolConfig {
pub fn set_zklogin_auth(&mut self, val: bool) {
self.feature_flags.zklogin_auth = val
}
pub fn set_max_tx_gas_for_testing(&mut self, max_tx_gas: u64) {
self.max_tx_gas = Some(max_tx_gas)
}
pub fn set_execution_version_for_testing(&mut self, version: u64) {
self.execution_version = Some(version)
}
pub fn set_upgraded_multisig_for_testing(&mut self, val: bool) {
self.feature_flags.upgraded_multisig_supported = val
}
Expand Down Expand Up @@ -1546,6 +1534,23 @@ mod test {
);
}

#[test]
fn test_setters() {
let mut prot: ProtocolConfig =
ProtocolConfig::get_for_version(ProtocolVersion::new(1), Chain::Unknown);
prot.set_max_arguments_for_testing(123);
assert_eq!(prot.max_arguments(), 123);

prot.set_max_arguments_from_str_for_testing("321".to_string());
assert_eq!(prot.max_arguments(), 321);

prot.disable_max_arguments_for_testing();
assert_eq!(prot.max_arguments_as_option(), None);

prot.set_attr_for_testing("max_arguments".to_string(), "456".to_string());
assert_eq!(prot.max_arguments(), 456);
}

#[test]
fn lookup_by_string_test() {
let prot: ProtocolConfig =
Expand Down