Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT MERGE] feat: contract entrypoint dispatch function #8726

Closed
wants to merge 15 commits into from
Closed
134 changes: 134 additions & 0 deletions noir-projects/aztec-nr/aztec/src/macros/dispatch/mod.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use protocol_types::abis::function_selector::FunctionSelector;
use crate::context::inputs::public_context_inputs::PublicContextInputs;
use super::utils::compute_fn_selector;

/// Returns an `fn public_dispatch(...)` function for the given module that's assumed to be an Aztec contract.
asterite marked this conversation as resolved.
Show resolved Hide resolved
pub comptime fn generate_public_dispatch(m: Module) -> Quoted {
let functions = m.functions();
let functions = functions.filter(|function: FunctionDefinition| function.has_named_attribute("public"));

let public_context_inputs = get_type::<PublicContextInputs>();
asterite marked this conversation as resolved.
Show resolved Hide resolved
let unit = get_type::<()>();
let initial_offset = size_in_fields(public_context_inputs) + 1; // +1 for the Field selector

let ifs = functions.map(
|function: FunctionDefinition| {
let name = function.name();
let parameters = function.parameters();
let return_type = function.return_type();

let selector: Field = compute_fn_selector(function);

let mut parameter_index = 0;
let mut offset = initial_offset;

let reads = parameters.map(|param: (Quoted, Type)| {
// Skip the `PublicContextInputs` argument as we already have that
let read = if parameter_index == 0 {
quote {}
} else {
let param_type = param.1;
let param_size = size_in_fields(param_type);
let param_name = f"arg{parameter_index}".quoted_contents();
let read = quote { let $param_name = dep::aztec::protocol_types::traits::Deserialize::deserialize(dep::aztec::context::public_context::calldata_copy($offset, $param_size)); };
offset += param_size;
quote { $read }
};
parameter_index += 1;
read
});
let read = reads.join(quote { });

let mut args = &[];
for parameter_index in 0..parameters.len() {
let param_name = f"arg{parameter_index}".quoted_contents();
args = args.push_back(quote { $param_name });
}

let args = args.join(quote { , });
let call = quote { $name($args) };

let return_code = if return_type == unit {
quote { $call }
} else {
quote {
let return_value = dep::aztec::protocol_types::traits::Serialize::serialize($call);
dep::aztec::context::public_context::avm_return(return_value);
}
};

let if_ = quote {
if selector == $selector {
$read
$return_code
}
};
if_
}
);

let ifs = ifs.push_back(quote { { panic(f"Unknown selector") } });
let dispatch = ifs.join(quote { else });

let body = quote {
pub fn public_dispatch(arg0: $public_context_inputs, selector: Field) {
fcarreiro marked this conversation as resolved.
Show resolved Hide resolved
$dispatch
}
};

println(body);
asterite marked this conversation as resolved.
Show resolved Hide resolved

body
}

comptime fn size_in_fields(typ: Type) -> u32 {
if typ.as_slice().is_some() {
panic(f"Can't determine size in fields of Slice type")
} else {
let size = array_size_in_fields(typ);
let size = size.or_else(|| struct_size_in_fields(typ));
let size = size.or_else(|| tuple_size_in_fields(typ));
size.unwrap_or(1)
}
}

comptime fn array_size_in_fields(typ: Type) -> Option<u32> {
typ.as_array().and_then(
|typ: (Type, Type)| {
let (typ, element_size) = typ;
element_size.as_constant().map(|x: u32| {
x * size_in_fields(typ)
})
}
)
}

comptime fn struct_size_in_fields(typ: Type) -> Option<u32> {
typ.as_struct().map(
|typ: (StructDefinition, [Type])| {
let struct_type = typ.0;
let mut size = 0;
for field in struct_type.fields() {
size += size_in_fields(field.1);
}
size
}
)
}

comptime fn tuple_size_in_fields(typ: Type) -> Option<u32> {
typ.as_tuple().map(
|types: [Type]| {
let mut size = 0;
for typ in types {
size += size_in_fields(typ);
}
size
}
)
}

comptime fn get_type<T>() -> Type {
let t: T = std::mem::zeroed();
std::meta::type_of(t)
}
7 changes: 4 additions & 3 deletions noir-projects/aztec-nr/aztec/src/macros/mod.nr
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod dispatch;
mod functions;
mod utils;
mod notes;
Expand All @@ -10,26 +11,26 @@ use notes::{NOTES, generate_note_export};

use functions::transform_unconstrained;
use utils::module_has_storage;
use dispatch::generate_public_dispatch;

/// Marks a contract as an Aztec contract, generating the interfaces for its functions and notes, as well as injecting
/// the `compute_note_hash_and_optionally_a_nullifier` function PXE requires in order to validate notes.
pub comptime fn aztec(m: Module) -> Quoted {
let interface = generate_contract_interface(m);

let unconstrained_functions = m.functions().filter(
| f: FunctionDefinition | f.is_unconstrained() & !f.has_named_attribute("test") & !f.has_named_attribute("public")
);
for f in unconstrained_functions {
transform_unconstrained(f);
}

let compute_note_hash_and_optionally_a_nullifier = generate_compute_note_hash_and_optionally_a_nullifier();
let note_exports = generate_note_exports();

let public_dispatch = generate_public_dispatch(m);
quote {
$note_exports
$interface
$compute_note_hash_and_optionally_a_nullifier
$public_dispatch
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -567,3 +567,15 @@ contract AvmTest {
let _ = nested_static_call_to_add(inputs, 1, 2);
}
}

mod tests {
use super::AvmTest;
use aztec::context::inputs::public_context_inputs::PublicContextInputs;

#[test]
fn contract_has_public_dispatch() {
let selector = 0;
let inputs = PublicContextInputs::empty();
AvmTest::public_dispatch(inputs, selector);
}
fcarreiro marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,99 @@ impl Deserialize<FIELD_SERIALIZED_LEN> for Field {
fields[0]
}
}

impl <let N: u32> Serialize<N> for [u8; N] {
fn serialize(self) -> [Field; N] {
self.map(|value| value as Field)
}
}

impl <let N: u32> Deserialize<N> for [u8; N] {
fn deserialize(fields: [Field; N]) -> Self {
fields.map(|value| value as u8)
}
}

impl <let N: u32> Serialize<N> for [u16; N] {
fn serialize(self) -> [Field; N] {
self.map(|value| value as Field)
}
}

impl <let N: u32> Deserialize<N> for [u16; N] {
fn deserialize(fields: [Field; N]) -> Self {
fields.map(|value| value as u16)
}
}

impl <let N: u32> Serialize<N> for [u32; N] {
fn serialize(self) -> [Field; N] {
self.map(|value| value as Field)
}
}

impl <let N: u32> Deserialize<N> for [u32; N] {
fn deserialize(fields: [Field; N]) -> Self {
fields.map(|value| value as u32)
}
}

impl <let N: u32> Serialize<N> for [u64; N] {
fn serialize(self) -> [Field; N] {
self.map(|value| value as Field)
}
}

impl <let N: u32> Deserialize<N> for [u64; N] {
fn deserialize(fields: [Field; N]) -> Self {
fields.map(|value| value as u64)
}
}

impl <let N: u32> Serialize<N> for [i8; N] {
fn serialize(self) -> [Field; N] {
self.map(|value| value as Field)
}
}

impl <let N: u32> Deserialize<N> for [i8; N] {
fn deserialize(fields: [Field; N]) -> Self {
fields.map(|value| value as i8)
}
}

impl <let N: u32> Serialize<N> for [i16; N] {
fn serialize(self) -> [Field; N] {
self.map(|value| value as Field)
}
}

impl <let N: u32> Deserialize<N> for [i16; N] {
fn deserialize(fields: [Field; N]) -> Self {
fields.map(|value| value as i16)
}
}

impl <let N: u32> Serialize<N> for [i32; N] {
fn serialize(self) -> [Field; N] {
self.map(|value| value as Field)
}
}

impl <let N: u32> Deserialize<N> for [i32; N] {
fn deserialize(fields: [Field; N]) -> Self {
fields.map(|value| value as i32)
}
}

impl <let N: u32> Serialize<N> for [i64; N] {
fn serialize(self) -> [Field; N] {
self.map(|value| value as Field)
}
}

impl <let N: u32> Deserialize<N> for [i64; N] {
fn deserialize(fields: [Field; N]) -> Self {
fields.map(|value| value as i64)
}
}
Loading