diff --git a/examples/name-service/programs/name-service/src/lib.rs b/examples/name-service/programs/name-service/src/lib.rs index e5af52d75f..1f7c11990c 100644 --- a/examples/name-service/programs/name-service/src/lib.rs +++ b/examples/name-service/programs/name-service/src/lib.rs @@ -86,6 +86,7 @@ pub enum CustomError { } #[light_accounts] +#[instruction(name: String)] pub struct CreateRecord<'info> { #[account(mut)] #[fee_payer] @@ -96,7 +97,7 @@ pub struct CreateRecord<'info> { #[authority] pub cpi_signer: AccountInfo<'info>, - #[light_account(init, seeds = [b"name-service", record.name.as_bytes()])] + #[light_account(init, seeds = [b"name-service", name.as_bytes()])] pub record: LightAccount, } diff --git a/macros/light-sdk-macros/src/accounts.rs b/macros/light-sdk-macros/src/accounts.rs index 7442064b07..070950f40a 100644 --- a/macros/light-sdk-macros/src/accounts.rs +++ b/macros/light-sdk-macros/src/accounts.rs @@ -1,11 +1,12 @@ use proc_macro2::{Span, TokenStream}; -use quote::quote; +use quote::{quote, ToTokens}; use syn::{ parse::{Parse, ParseStream}, parse_quote, punctuated::Punctuated, token::PathSep, - Error, Expr, Fields, Ident, ItemStruct, Meta, Path, PathSegment, Result, Token, Type, TypePath, + Error, Expr, Fields, Ident, ItemStruct, Meta, Path, PathSegment, Result, Stmt, Token, Type, + TypePath, }; pub(crate) fn process_light_system_accounts(input: ItemStruct) -> Result { @@ -75,6 +76,64 @@ pub(crate) fn process_light_system_accounts(input: ItemStruct) -> Result, + param_names: Vec, +} + +impl Parse for InstructionArgs { + fn parse(input: ParseStream) -> Result { + let mut param_type_checks = Vec::new(); + let mut param_names = Vec::new(); + + while !input.is_empty() { + let ident = input.parse::()?; + input.parse::()?; + let ty = input.parse::()?; + + param_names.push(ident.clone()); + param_type_checks.push(ParamTypeCheck { ident, ty }); + + if input.peek(Token![,]) { + input.parse::()?; + } + } + + Ok(InstructionArgs { + param_type_checks, + param_names, + }) + } +} + +/// Takes an input struct annotated with `#[light_accounts]` attribute and +/// then: +/// +/// - Creates a separate struct with `Light` prefix and moves compressed +/// account fields (annotated with `#[light_account]` attribute) to it. As a +/// result, the original struct, later processed by Anchor macros, contains +/// only regular accounts. +/// - Creates an extention trait, with `LightContextExt` prefix, which serves +/// as an extension to `LightContext` and defines these methods: +/// - `check_constraints`, where the checks extracted from `#[light_account]` +/// attributes are performed. +/// - `derive_address_seeds`, where the seeds extracted from +/// `#[light_account]` attributes are used to derive the address. pub(crate) fn process_light_accounts(input: ItemStruct) -> Result { let mut anchor_accounts_strct = input.clone(); @@ -82,6 +141,18 @@ pub(crate) fn process_light_accounts(input: ItemStruct) -> Result { let anchor_accounts_name = input.ident.clone(); let light_accounts_name = Ident::new(&format!("Light{}", input.ident), Span::call_site()); + let ext_trait_name = Ident::new( + &format!("LightContextExt{}", input.ident), + Span::call_site(), + ); + let params_name = Ident::new(&format!("Params{}", input.ident), Span::call_site()); + + let instruction_params = input + .attrs + .iter() + .find(|attribute| attribute.path().is_ident("instruction")) + .map(|attribute| attribute.parse_args::()) + .transpose()?; let mut light_accounts_fields: Punctuated = Punctuated::new(); @@ -94,11 +165,18 @@ pub(crate) fn process_light_accounts(input: ItemStruct) -> Result { )), }; + // Fields which should belong to the Anchor instruction struct. let mut anchor_fields = Punctuated::new(); + // Names of fields which should belong to the Anchor instruction struct. let mut anchor_field_idents = Vec::new(); + // Names of fields which should belong to the Light instruction struct. let mut light_field_idents = Vec::new(); + // Names of fields of the Light instruction struct, which should be + // available in constraints. + let mut light_referrable_field_idents = Vec::new(); let mut constraint_calls = Vec::new(); let mut derive_address_seed_calls = Vec::new(); + let mut set_address_seed_calls = Vec::new(); for field in fields.named.iter() { let mut light_account = false; @@ -135,6 +213,10 @@ pub(crate) fn process_light_accounts(input: ItemStruct) -> Result { } }; + if account_args.action != LightAccountAction::Init { + light_referrable_field_idents.push(field.ident.clone()); + } + if let Some(constraint) = account_args.constraint { let Constraint { expr, error } = constraint; let error = match error { @@ -157,8 +239,10 @@ pub(crate) fn process_light_accounts(input: ItemStruct) -> Result { &crate::ID, &unpacked_address_merkle_context, ); - #field_ident.set_address_seed(address_seed); }); + set_address_seed_calls.push(quote! { + #field_ident.set_address_seed(address_seed); + }) } else { anchor_fields.push(field.clone()); anchor_field_idents.push(field.ident.clone()); @@ -181,6 +265,27 @@ pub(crate) fn process_light_accounts(input: ItemStruct) -> Result { } }; + let light_referrable_fields = if light_referrable_field_idents.is_empty() { + quote! {} + } else { + quote! { + let #light_accounts_name { + #(#light_referrable_field_idents),*, .. + } = &self.light_accounts; + } + }; + let input_fields = match instruction_params { + Some(instruction_params) => { + let param_names = instruction_params.param_names; + let param_type_checks = instruction_params.param_type_checks; + quote! { + let #params_name { #(#param_names),*, .. } = inputs; + #(#param_type_checks)* + } + } + None => quote! {}, + }; + let expanded = quote! { #[::light_sdk::light_system_accounts] #[derive(::anchor_lang::Accounts, ::light_sdk::LightTraits)] @@ -188,14 +293,32 @@ pub(crate) fn process_light_accounts(input: ItemStruct) -> Result { #light_accounts_strct - impl<'a, 'b, 'c, 'info> LightContextExt for ::light_sdk::context::LightContext< + pub trait #ext_trait_name { + fn check_constraints( + &self, + inputs: &#params_name, + ) -> Result<()>; + fn derive_address_seeds( + &mut self, + address_merkle_context: ::light_sdk::merkle_context::PackedAddressMerkleContext, + inputs: &#params_name, + ); + } + + impl<'a, 'b, 'c, 'info> #ext_trait_name for ::light_sdk::context::LightContext< 'a, 'b, 'c, 'info, #anchor_accounts_name #type_gen, #light_accounts_name, > { #[allow(unused_parens)] #[allow(unused_variables)] - fn check_constraints(&self) -> Result<()> { - let #anchor_accounts_name { #(#anchor_field_idents),*, .. } = &self.anchor_context.accounts; - let #light_accounts_name { #(#light_field_idents),* } = &self.light_accounts; + fn check_constraints( + &self, + inputs: &#params_name, + ) -> Result<()> { + let #anchor_accounts_name { + #(#anchor_field_idents),*, .. + } = &self.anchor_context.accounts; + #light_referrable_fields + #input_fields #(#constraint_calls)* @@ -206,15 +329,23 @@ pub(crate) fn process_light_accounts(input: ItemStruct) -> Result { fn derive_address_seeds( &mut self, address_merkle_context: PackedAddressMerkleContext, + inputs: &#params_name, ) { - let #anchor_accounts_name { #(#anchor_field_idents),*, .. } = &self.anchor_context.accounts; - let #light_accounts_name { #(#light_field_idents),* } = &mut self.light_accounts; + let #anchor_accounts_name { + #(#anchor_field_idents),*, .. + } = &self.anchor_context.accounts; + #light_referrable_fields + #input_fields let unpacked_address_merkle_context = ::light_sdk::program_merkle_context::unpack_address_merkle_context( address_merkle_context, self.anchor_context.remaining_accounts); #(#derive_address_seed_calls)* + + let #light_accounts_name { #(#light_field_idents),* } = &mut self.light_accounts; + + #(#set_address_seed_calls)* } } }; @@ -222,7 +353,7 @@ pub(crate) fn process_light_accounts(input: ItemStruct) -> Result { Ok(expanded) } -mod kw { +mod light_account_kw { // Action syn::custom_keyword!(init); syn::custom_keyword!(close); @@ -232,6 +363,7 @@ mod kw { syn::custom_keyword!(seeds); } +#[derive(Eq, PartialEq)] pub(crate) enum LightAccountAction { Init, Mut, @@ -263,20 +395,20 @@ impl Parse for LightAccountArgs { let lookahead = input.lookahead1(); // Actions - if lookahead.peek(kw::init) { - input.parse::()?; + if lookahead.peek(light_account_kw::init) { + input.parse::()?; action = Some(LightAccountAction::Init); } else if lookahead.peek(Token![mut]) { input.parse::()?; action = Some(LightAccountAction::Mut); - } else if lookahead.peek(kw::close) { - input.parse::()?; + } else if lookahead.peek(light_account_kw::close) { + input.parse::()?; action = Some(LightAccountAction::Close); } // Constraint - else if lookahead.peek(kw::constraint) { + else if lookahead.peek(light_account_kw::constraint) { // Parse the constraint. - input.parse::()?; + input.parse::()?; input.parse::()?; let expr: Expr = input.parse()?; @@ -290,8 +422,8 @@ impl Parse for LightAccountArgs { constraint = Some(Constraint { expr, error }); } // Seeds - else if lookahead.peek(kw::seeds) { - input.parse::()?; + else if lookahead.peek(light_account_kw::seeds) { + input.parse::()?; input.parse::()?; seeds = Some(input.parse::()?); } else { @@ -427,7 +559,7 @@ pub(crate) fn process_light_accounts_derive(input: ItemStruct) -> Result>, +/// name: String, +/// num: u32, +/// ) -> Result<()> {} +/// +/// pub fn instruction_two( +/// ctx: LightContext<'_, '_, '_, 'info, InstructionTwo<'info>>, +/// num_one: u32, +/// num_two: u64, +/// ) -> Result<()> {} +/// } +/// ``` +/// +/// The mapping is going to look like: +/// +/// ``` +/// instruction_one -> - name: name +/// ty: String +/// - name: num +/// ty: u32 +/// +/// instruction_two -> - name: num_one +/// ty: u32 +/// - name: num_two +/// ty: u64 +/// ``` #[derive(Default)] -struct LightProgramTransform {} +struct InstructionParams(HashMap>); + +impl Deref for InstructionParams { + type Target = HashMap>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for InstructionParams { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// Implementation of `ToTokens` which allows to convert the +/// instruction-parameter mapping to structs, which we later use for packing +/// of parameters for convenient usage in `LightContext` extensions produced in +/// `accounts.rs` - precisely, in the `check_constraints` and +/// `derive_address_seeds` methods. +impl ToTokens for InstructionParams { + fn to_tokens(&self, tokens: &mut TokenStream) { + for (name, inputs) in self.0.iter() { + let name = Ident::new(name, Span::call_site()); + let strct: ItemStruct = parse_quote! { + pub struct #name { + #(#inputs),* + } + }; + strct.to_tokens(tokens); + } + } +} + +#[derive(Default)] +struct LightProgramTransform { + /// Mapping of instructions to their parameters in the program. + instruction_params: InstructionParams, +} impl VisitMut for LightProgramTransform { fn visit_item_fn_mut(&mut self, i: &mut ItemFn) { @@ -18,6 +117,20 @@ impl VisitMut for LightProgramTransform { }; i.attrs.push(clippy_attr); + // Gather names instruction parameters (arguments other than `ctx`). + // They are going to be used to generate `Inputs*` structs. + let mut instruction_params = Vec::with_capacity(i.sig.inputs.len() - 1); + for input in i.sig.inputs.iter().skip(1) { + if let FnArg::Typed(input) = input { + if let Pat::Ident(ref pat_ident) = *input.pat { + instruction_params.push(InstructionParam { + name: pat_ident.ident.clone(), + ty: (*input.ty).clone(), + }); + } + } + } + // Find the `ctx` argument. let ctx_arg = i.sig.inputs.first_mut().unwrap(); @@ -62,6 +175,13 @@ impl VisitMut for LightProgramTransform { let light_accounts_name = format!("Light{}", accounts_segment.ident); let light_accounts_ident = Ident::new(&light_accounts_name, Span::call_site()); + // Add the previously gathered instruction inputs to the mapping of + // instructions to their parameters (`self.instruction_inputs`). + let params_name = format!("Params{}", accounts_segment.ident); + self.instruction_params + .insert(params_name.clone(), instruction_params.clone()); + let inputs_ident = Ident::new(¶ms_name, Span::call_site()); + // Inject an `inputs: Vec>` argument to all instructions. The // purpose of that additional argument is passing compressed accounts. let inputs_arg: FnArg = parse_quote! { inputs: Vec> }; @@ -97,21 +217,46 @@ impl VisitMut for LightProgramTransform { }; i.block.stmts.insert(0, light_context_stmt); - // Inject `check_constraints` call right after. - let check_constraints_stmt: Stmt = parse_quote! { - ctx.check_constraints()?; + // Pack all instruction inputs in a struct, which then can be used in + // `check_constrants` and `derive_address_seeds`. + // + // We do that, because passing one reference to these methods is more + // comfortable. Passing references to each input separately would + // require even messier code... + // + // We move the inputs to that struct, so no copies are being made. + let input_idents = instruction_params + .iter() + .map(|input| input.name.clone()) + .collect::>(); + let inputs_pack_stmt: Stmt = parse_quote! { + let inputs = #inputs_ident { #(#input_idents),* }; }; - i.block.stmts.insert(1, check_constraints_stmt); + i.block.stmts.insert(1, inputs_pack_stmt); - // Inject `derive_address_seeds` and `verify` statements at the end of - // the function. - let stmts_len = i.block.stmts.len(); + // Inject `check_constraints` and `derive_address_seeds` calls right + // after. + let check_constraints_stmt: Stmt = parse_quote! { + ctx.check_constraints(&inputs)?; + }; + i.block.stmts.insert(2, check_constraints_stmt); let derive_address_seed_stmt: Stmt = parse_quote! { - ctx.derive_address_seeds(address_merkle_context); + ctx.derive_address_seeds(address_merkle_context, &inputs); }; - i.block - .stmts - .insert(stmts_len - 1, derive_address_seed_stmt); + i.block.stmts.insert(3, derive_address_seed_stmt); + + // Once we are done with calling `check_constraints` and + // `derive_address_seeds`, we can unpack the inputs, so developers can + // use them as regular variables in their code. + // + // Unpacking of the struct means moving the values and no copies are + // being made. + let inputs_unpack_stmt: Stmt = parse_quote! { + let #inputs_ident { #(#input_idents),* } = inputs; + }; + i.block.stmts.insert(4, inputs_unpack_stmt); + + // Inject `verify` statements at the end of the function. let stmts_len = i.block.stmts.len(); let verify_stmt: Stmt = parse_quote! { ctx.verify(proof)?; @@ -135,11 +280,10 @@ pub(crate) fn program(mut input: ItemMod) -> Result { let mut transform = LightProgramTransform::default(); transform.visit_item_mod_mut(&mut input); + let instruction_params = transform.instruction_params; + Ok(quote! { - pub trait LightContextExt { - fn check_constraints(&self) -> Result<()>; - fn derive_address_seeds(&mut self, address_merkle_context: PackedAddressMerkleContext); - } + #instruction_params #[program] #input