Skip to content

Commit

Permalink
feat: macros that decoding arguments can set custom decoder using dec…
Browse files Browse the repository at this point in the history
…ode_with (#544)

* cleanup dfn_macro internal

* no arg decoding if function sig has no args

* name check

* decode_with: set custom arg decoder
  • Loading branch information
lwshang authored Jan 10, 2025
1 parent 14e10db commit 3184c9c
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 72 deletions.
28 changes: 28 additions & 0 deletions e2e-tests/src/bin/macros.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use candid::utils::{decode_args, decode_one};
use ic_cdk::api::msg_arg_data;
use ic_cdk::update;

#[update(decode_with = "decode_u0")]
fn u0() {}
fn decode_u0() {}

#[update(decode_with = "decode_u1")]
fn u1(a: u32) {
assert_eq!(a, 1)
}
fn decode_u1() -> u32 {
let arg_bytes = msg_arg_data();
decode_one(&arg_bytes).unwrap()
}

#[update(decode_with = "decode_u2")]
fn u2(a: u32, b: u32) {
assert_eq!(a, 1);
assert_eq!(b, 2);
}
fn decode_u2() -> (u32, u32) {
let arg_bytes = msg_arg_data();
decode_args(&arg_bytes).unwrap()
}

fn main() {}
31 changes: 31 additions & 0 deletions e2e-tests/tests/macros.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use pocket_ic::call_candid;
use pocket_ic::common::rest::RawEffectivePrincipal;

mod test_utilities;
use test_utilities::{cargo_build_canister, pocket_ic};

#[test]
fn call_macros() {
let pic = pocket_ic();
let wasm = cargo_build_canister("macros");
let canister_id = pic.create_canister();
pic.add_cycles(canister_id, 100_000_000_000_000);
pic.install_canister(canister_id, wasm, vec![], None);
let _: () = call_candid(&pic, canister_id, RawEffectivePrincipal::None, "u0", ()).unwrap();
let _: () = call_candid(
&pic,
canister_id,
RawEffectivePrincipal::None,
"u1",
(1u32,),
)
.unwrap();
let _: () = call_candid(
&pic,
canister_id,
RawEffectivePrincipal::None,
"u2",
(1u32, 2u32),
)
.unwrap();
}
184 changes: 112 additions & 72 deletions ic-cdk-macros/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use syn::{spanned::Spanned, FnArg, ItemFn, Pat, PatIdent, PatType, ReturnType, S
struct ExportAttributes {
pub name: Option<String>,
pub guard: Option<String>,
pub decode_with: Option<String>,
#[serde(default)]
pub manual_reply: bool,
#[serde(default)]
Expand All @@ -31,6 +32,12 @@ enum MethodType {
}

impl MethodType {
/// A lifecycle method is a method that is called by the system and not by the user.
/// So far, `update` and `query` are the only methods that are not lifecycle methods.
///
/// We have a few assumptions for lifecycle methods:
/// - They cannot have a return value.
/// - The export name is prefixed with `canister_`, e.g. `init` => `canister_init`.
pub fn is_lifecycle(&self) -> bool {
match self {
MethodType::Init
Expand All @@ -42,6 +49,19 @@ impl MethodType {
MethodType::Update | MethodType::Query => false,
}
}

/// init, post_upgrade, update, query can have arguments.
pub fn can_have_args(&self) -> bool {
match self {
MethodType::Init | MethodType::PostUpgrade | MethodType::Update | MethodType::Query => {
true
}
MethodType::PreUpgrade
| MethodType::Heartbeat
| MethodType::InspectMessage
| MethodType::OnLowWasmMemory => false,
}
}
}

impl std::fmt::Display for MethodType {
Expand Down Expand Up @@ -121,79 +141,38 @@ fn dfn_macro(
));
}

let is_async = signature.asyncness.is_some();

let return_length = match &signature.output {
ReturnType::Default => 0,
ReturnType::Type(_, ty) => match ty.as_ref() {
Type::Tuple(tuple) => tuple.elems.len(),
_ => 1,
},
};

if method.is_lifecycle() && return_length > 0 {
return Err(Error::new(
Span::call_site(),
format!("#[{}] function cannot have a return value.", method),
));
}

let (arg_tuple, _): (Vec<Ident>, Vec<Box<Type>>) =
get_args(method, signature)?.iter().cloned().unzip();
// 1. function name(s)
let name = &signature.ident;

let outer_function_ident = format_ident!("__canister_method_{name}");

let function_name = attrs.name.unwrap_or_else(|| name.to_string());
let export_name = if method.is_lifecycle() {
format!("canister_{}", method)
} else if method == MethodType::Query && attrs.composite {
format!("canister_composite_query {function_name}",)
} else {
if function_name.starts_with("<ic-cdk internal>") {
let function_name = if let Some(custom_name) = attrs.name {
if method.is_lifecycle() {
return Err(Error::new(
Span::call_site(),
attr.span(),
format!("#[{0}] cannot have a custom name.", method),
));
}
if custom_name.starts_with("<ic-cdk internal>") {
return Err(Error::new(
attr.span(),
"Functions starting with `<ic-cdk internal>` are reserved for CDK internal use.",
));
}
format!("canister_{method} {function_name}")
};
let host_compatible_name = export_name.replace(' ', ".").replace(['-', '<', '>'], "_");

let function_call = if is_async {
quote! { #name ( #(#arg_tuple),* ) .await }
custom_name
} else {
quote! { #name ( #(#arg_tuple),* ) }
name.to_string()
};

let arg_count = arg_tuple.len();

let return_encode = if method.is_lifecycle() || attrs.manual_reply {
quote! {}
} else {
let return_bytes = match return_length {
0 => quote! { ::candid::utils::encode_one(()).unwrap() },
1 => quote! { ::candid::utils::encode_one(result).unwrap() },
_ => quote! { ::candid::utils::encode_args(result).unwrap() },
};
quote! {
::ic_cdk::api::msg_reply(#return_bytes);
}
};

// On initialization we can actually not receive any input and it's okay, only if
// we don't have any arguments either.
// If the data we receive is not empty, then try to unwrap it as if it's DID.
let arg_decode = if method.is_lifecycle() && arg_count == 0 {
quote! {}
let export_name = if method.is_lifecycle() {
format!("canister_{}", method)
} else if method == MethodType::Query && attrs.composite {
format!("canister_composite_query {function_name}",)
} else {
quote! {
let arg_bytes = ::ic_cdk::api::msg_arg_data();
let ( #( #arg_tuple, )* ) = ::candid::utils::decode_args(&arg_bytes).unwrap(); }
format!("canister_{method} {function_name}")
};
let host_compatible_name = export_name.replace(' ', ".").replace(['-', '<', '>'], "_");

// 2. guard
let guard = if let Some(guard_name) = attrs.guard {
// ic_cdk::api::call::reject calls ic0::msg_reject which is only allowed in update/query
// ic0.msg_reject is only allowed in update/query
if method.is_lifecycle() {
return Err(Error::new(
attr.span(),
Expand All @@ -213,6 +192,78 @@ fn dfn_macro(
quote! {}
};

// 3. decode arguments
let (arg_tuple, _): (Vec<Ident>, Vec<Box<Type>>) =
get_args(method, signature)?.iter().cloned().unzip();
if !method.can_have_args() {
if !arg_tuple.is_empty() {
return Err(Error::new(
Span::call_site(),
format!("#[{}] function cannot have arguments.", method),
));
}
if attrs.decode_with.is_some() {
return Err(Error::new(
attr.span(),
format!(
"#[{}] function cannot have a decode_with attribute.",
method
),
));
}
}
let arg_decode = if let Some(decode_with) = attrs.decode_with {
let decode_with_ident = syn::Ident::new(&decode_with, Span::call_site());
if arg_tuple.len() == 1 {
let arg_one = &arg_tuple[0];
quote! { let #arg_one = #decode_with_ident(); }
} else {
quote! { let ( #( #arg_tuple, )* ) = #decode_with_ident(); }
}
} else if arg_tuple.is_empty() {
quote! {}
} else {
quote! {
let arg_bytes = ::ic_cdk::api::msg_arg_data();
let ( #( #arg_tuple, )* ) = ::candid::utils::decode_args(&arg_bytes).unwrap();
}
};

// 4. function call
let function_call = if signature.asyncness.is_some() {
quote! { #name ( #(#arg_tuple),* ) .await }
} else {
quote! { #name ( #(#arg_tuple),* ) }
};

// 5. return
let return_length = match &signature.output {
ReturnType::Default => 0,
ReturnType::Type(_, ty) => match ty.as_ref() {
Type::Tuple(tuple) => tuple.elems.len(),
_ => 1,
},
};
if method.is_lifecycle() && return_length > 0 {
return Err(Error::new(
Span::call_site(),
format!("#[{}] function cannot have a return value.", method),
));
}
let return_encode = if method.is_lifecycle() || attrs.manual_reply {
quote! {}
} else {
let return_bytes = match return_length {
0 => quote! { ::candid::utils::encode_one(()).unwrap() },
1 => quote! { ::candid::utils::encode_one(result).unwrap() },
_ => quote! { ::candid::utils::encode_args(result).unwrap() },
};
quote! {
::ic_cdk::api::msg_reply(#return_bytes);
}
};

// 6. candid attributes for export_candid!()
let candid_method_attr = if attrs.hidden {
quote! {}
} else {
Expand Down Expand Up @@ -262,9 +313,6 @@ pub(crate) fn ic_update(attr: TokenStream, item: TokenStream) -> Result<TokenStr
dfn_macro(MethodType::Update, attr, item)
}

#[derive(Default, Deserialize)]
struct InitAttributes {}

pub(crate) fn ic_init(attr: TokenStream, item: TokenStream) -> Result<TokenStream, Error> {
dfn_macro(MethodType::Init, attr, item)
}
Expand Down Expand Up @@ -320,8 +368,6 @@ mod test {
fn #fn_name() {
::ic_cdk::setup();
::ic_cdk::spawn(async {
let arg_bytes = ::ic_cdk::api::msg_arg_data();
let () = ::candid::utils::decode_args(&arg_bytes).unwrap();
let result = query();
::ic_cdk::api::msg_reply(::candid::utils::encode_one(()).unwrap());
});
Expand Down Expand Up @@ -359,8 +405,6 @@ mod test {
fn #fn_name() {
::ic_cdk::setup();
::ic_cdk::spawn(async {
let arg_bytes = ::ic_cdk::api::msg_arg_data();
let () = ::candid::utils::decode_args(&arg_bytes).unwrap();
let result = query();
::ic_cdk::api::msg_reply(::candid::utils::encode_one(result).unwrap());
});
Expand Down Expand Up @@ -398,8 +442,6 @@ mod test {
fn #fn_name() {
::ic_cdk::setup();
::ic_cdk::spawn(async {
let arg_bytes = ::ic_cdk::api::msg_arg_data();
let () = ::candid::utils::decode_args(&arg_bytes).unwrap();
let result = query();
::ic_cdk::api::msg_reply(::candid::utils::encode_args(result).unwrap());
});
Expand Down Expand Up @@ -553,8 +595,6 @@ mod test {
fn #fn_name() {
::ic_cdk::setup();
::ic_cdk::spawn(async {
let arg_bytes = ::ic_cdk::api::msg_arg_data();
let () = ::candid::utils::decode_args(&arg_bytes).unwrap();
let result = query();
::ic_cdk::api::msg_reply(::candid::utils::encode_one(()).unwrap());
});
Expand Down

0 comments on commit 3184c9c

Please sign in to comment.