Skip to content

Commit

Permalink
feat(avm/public): user space PublicContext::get_args_hash (#8292)
Browse files Browse the repository at this point in the history
This PR implements `PublicContext::get_args_hash` in user space. We are
still passing the calldata length as a runtime variable until we can get
it at compile time. This requires @Thunkar 's work on `aztec(public)` as
a macro.

Once that is done, we'll pass the hasher as a closure when creating the
PublicContext, i.e.:
```
struct PublicContext {
    hash_getter: fn[(Field,)]() -> Field,
    // ...
}

impl PublicContext {
    pub fn new(..., hash_getter) -> Self {
        // ...
    }

    fn get_args_hash(self) -> Field {
        (self.hash_getter)()
    }
}

// In the aztec(public) macro
comptime let N = get_calldata_length();
let hash_getter = || {
    let mut hasher = ArgsHasher::new();
    let mut fields = std::meta::unquote!(quote { [0; $N] });
    fields = calldata_copy(2 /*or 1*/, N);
    hasher.add_many(fields);
    hasher.hash()
};
let context = PublicContext::new(..., hash_getter);
```
  • Loading branch information
fcarreiro authored and AztecBot committed Sep 11, 2024
1 parent 1c9377d commit 4be77a0
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 7 deletions.
5 changes: 3 additions & 2 deletions aztec/src/context/inputs/public_context_inputs.nr
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ use dep::protocol_types::traits::Empty;

// These inputs will likely go away once the AVM processes 1 public kernel per enqueued call.
struct PublicContextInputs {
args_hash: Field,
// TODO: Remove this structure and get calldata size at compile time.
calldata_length: Field,
is_static_call: bool
}

impl Empty for PublicContextInputs {
fn empty() -> Self {
PublicContextInputs {
args_hash: 0,
calldata_length: 0,
is_static_call: false
}
}
Expand Down
27 changes: 24 additions & 3 deletions aztec/src/context/public_context.nr
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ use dep::protocol_types::traits::{Serialize, Deserialize, Empty};
use dep::protocol_types::abis::function_selector::FunctionSelector;
use crate::context::inputs::public_context_inputs::PublicContextInputs;
use crate::context::gas::GasOpts;
use crate::hash::ArgsHasher;

struct PublicContext {
inputs: PublicContextInputs,
args_hash: Option<Field>
}

impl PublicContext {
pub fn new(inputs: PublicContextInputs) -> Self {
PublicContext { inputs }
PublicContext { inputs, args_hash: Option::none() }
}

pub fn emit_unencrypted_log<T, let N: u32>(_self: &mut Self, log: T) where T: Serialize<N> {
Expand Down Expand Up @@ -130,8 +132,20 @@ impl PublicContext {
fn selector(_self: Self) -> FunctionSelector {
FunctionSelector::from_u32(function_selector())
}
fn get_args_hash(self) -> Field {
self.inputs.args_hash
fn get_args_hash(mut self) -> Field {
if !self.args_hash.is_some() {
let mut hasher = ArgsHasher::new();

// TODO: this should be replaced with the compile-time calldata size.
for i in 0..self.inputs.calldata_length as u32 {
let argn: [Field; 1] = calldata_copy((2 + i) as u32, 1);
hasher.add(argn[0]);
}

self.args_hash = Option::some(hasher.hash());
}

self.args_hash.unwrap()
}
fn transaction_fee(_self: Self) -> Field {
transaction_fee()
Expand Down Expand Up @@ -278,6 +292,10 @@ unconstrained fn call_static<let RET_SIZE: u32>(
call_static_opcode(gas, address, args, function_selector)
}

unconstrained fn calldata_copy<let N: u32>(cdoffset: u32, copy_size: u32) -> [Field; N] {
calldata_copy_opcode(cdoffset, copy_size)
}

unconstrained fn storage_read(storage_slot: Field) -> Field {
storage_read_opcode(storage_slot)
}
Expand Down Expand Up @@ -356,6 +374,9 @@ unconstrained fn l1_to_l2_msg_exists_opcode(msg_hash: Field, msg_leaf_index: Fie
#[oracle(avmOpcodeSendL2ToL1Msg)]
unconstrained fn send_l2_to_l1_msg_opcode(recipient: EthAddress, content: Field) {}

#[oracle(avmOpcodeCalldataCopy)]
unconstrained fn calldata_copy_opcode<let N: u32>(cdoffset: u32, copy_size: u32) -> [Field; N] {}

#[oracle(avmOpcodeCall)]
unconstrained fn call_opcode<let RET_SIZE: u32>(
gas: [Field; 2], // gas allocation: [l2_gas, da_gas]
Expand Down
7 changes: 7 additions & 0 deletions aztec/src/test/helpers/cheatcodes.nr
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ unconstrained pub fn set_msg_sender(msg_sender: AztecAddress) {
oracle_set_msg_sender(msg_sender)
}

unconstrained pub fn set_calldata(calldata: [Field]) {
oracle_set_calldata(calldata)
}

unconstrained pub fn get_msg_sender() -> AztecAddress {
oracle_get_msg_sender()
}
Expand Down Expand Up @@ -187,3 +191,6 @@ unconstrained fn oracle_get_function_selector() -> FunctionSelector {}

#[oracle(setFunctionSelector)]
unconstrained fn oracle_set_function_selector(selector: FunctionSelector) {}

#[oracle(setCalldata)]
unconstrained fn oracle_set_calldata(calldata: [Field]) {}
6 changes: 5 additions & 1 deletion aztec/src/test/helpers/test_environment.nr
Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,22 @@ impl TestEnvironment {
let original_fn_selector = cheatcodes::get_function_selector();
let target_address = call_interface.get_contract_address();
let fn_selector = call_interface.get_selector();
let calldata = call_interface.get_args();

cheatcodes::set_fn_selector(fn_selector);
cheatcodes::set_contract_address(target_address);
cheatcodes::set_msg_sender(original_contract_address);
let mut inputs = cheatcodes::get_public_context_inputs();
inputs.args_hash = hash_args(call_interface.get_args());
inputs.calldata_length = call_interface.get_args().len() as Field;
inputs.is_static_call = call_interface.get_is_static();
cheatcodes::set_calldata(calldata);

let result = original_fn(inputs);

cheatcodes::set_fn_selector(original_fn_selector);
cheatcodes::set_contract_address(original_contract_address);
cheatcodes::set_msg_sender(original_msg_sender);
cheatcodes::set_calldata(calldata);
result
}

Expand Down
7 changes: 6 additions & 1 deletion aztec/src/test/helpers/utils.nr
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,22 @@ impl<let N: u32, let M: u32> Deployer<N, M> {
let original_msg_sender = cheatcodes::get_msg_sender();
let original_contract_address = get_contract_address();
let original_fn_selector = cheatcodes::get_function_selector();
let calldata = call_interface.get_args();

cheatcodes::set_fn_selector(call_interface.get_selector());
cheatcodes::set_contract_address(instance.to_address());
cheatcodes::set_msg_sender(original_contract_address);
let mut inputs = cheatcodes::get_public_context_inputs();
inputs.args_hash = hash_args(call_interface.get_args());
inputs.calldata_length = call_interface.get_args().len() as Field;
inputs.is_static_call = call_interface.get_is_static();
cheatcodes::set_calldata(calldata);

let _result: T = original_fn(inputs);

cheatcodes::set_fn_selector(original_fn_selector);
cheatcodes::set_contract_address(original_contract_address);
cheatcodes::set_msg_sender(original_msg_sender);
cheatcodes::set_calldata(calldata);
instance
}

Expand Down

0 comments on commit 4be77a0

Please sign in to comment.