Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: validate counters #6365

Merged
merged 16 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions noir-projects/aztec-nr/aztec/src/context/private_context.nr
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,10 @@ impl ContextInterface for PrivateContext {

impl PrivateContext {
pub fn new(inputs: PrivateContextInputs, args_hash: Field) -> PrivateContext {
let side_effect_counter = inputs.start_side_effect_counter;
let mut min_revertible_side_effect_counter = 0;
if is_empty(inputs.call_context.msg_sender) {
min_revertible_side_effect_counter = side_effect_counter;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

min_revertible_side_effect_counter is only propagated to the public inputs in init kernel circuit. And it's on the entrypoint to decide if it has a value or not. It's not necessary to set it to be the same as the call's start counter if it's the first call (empty msg_sender).

}
PrivateContext {
inputs,
side_effect_counter,
min_revertible_side_effect_counter,
side_effect_counter: inputs.start_side_effect_counter + 1,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the side effects should have unique counters, including the start and end counter for a call, and all the data emitted within.

min_revertible_side_effect_counter: 0,
is_fee_payer: false,
args_hash,
return_hash: 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl PublicContext {
pub fn new(inputs: PublicContextInputs, args_hash: Field) -> PublicContext {
PublicContext {
inputs,
side_effect_counter: inputs.start_side_effect_counter,
side_effect_counter: inputs.start_side_effect_counter + 1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you confirm that we are doing this right in the AVM? (i.e., the initialization of side effect counters, the increment order [whether we use the current and then increment or vice-versa], calculation of end counter, and passing these across enqueued calls). I'm not sure the tests in master would currently exercise all this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point! Modified here to be consistent with the private land.
We are not checking the counters in public kernel. And it won't necessary have to +1 for AVM. It's incremented by 1 because it's easier for private kernel to validate if the counters are strictly increasing.

args_hash,
return_hash: 0,
nullifier_read_requests: BoundedVec::new(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use dep::types::{
abis::{
call_context::CallContext, call_request::CallRequest, private_call_stack_item::PrivateCallStackItem,
private_kernel::private_call_data::PrivateCallData
private_kernel::private_call_data::PrivateCallData, side_effect::Ordered
},
address::{AztecAddress, PartialAddress}, contract_class_id::ContractClassId,
hash::{private_functions_root_from_siblings, stdlib_recursion_verification_key_compress_native_vk},
Expand Down Expand Up @@ -50,9 +50,71 @@ fn validate_call_request(request: CallRequest, hash: Field, caller: PrivateCallS
}
}

fn validate_call_requests<N>(call_requests: [CallRequest; N], hashes: [Field; N], caller: PrivateCallStackItem) {
fn validate_incrementing_counters_within_range<T, N>(
counter_start: u32,
counter_end: u32,
items: [T; N],
num_items: u64
) where T: Ordered {
let mut prev_counter = counter_start;
let mut should_check = true;
for i in 0..N {
validate_call_request(call_requests[i], hashes[i], caller);
should_check &= i != num_items;
if should_check {
let item = items[i];
assert(
item.counter() > prev_counter, "counter must be larger than the counter of the previous item"
);
prev_counter = item.counter();
}
}
assert(prev_counter < counter_end, "counter must be smaller than the end counter of the call");
}

fn validate_incrementing_counter_ranges_within_range<N>(
counter_start: u32,
counter_end: u32,
items: [CallRequest; N],
num_items: u64
) {
let mut prev_counter = counter_start;
let mut should_check = true;
for i in 0..N {
should_check &= i != num_items;
if should_check {
let item = items[i];
assert(
item.start_side_effect_counter > prev_counter, "start counter must be larger than the end counter of the previous call"
);
assert(
item.end_side_effect_counter > item.start_side_effect_counter, "nested call has incorrect counter range"
);
prev_counter = item.end_side_effect_counter;
}
}
assert(
prev_counter < counter_end, "end counter must be smaller than the end counter of the parent call"
);
}

fn validate_split_call_requests<N>(
min_revertible_side_effect_counter: u32,
first_revertible_call_request_index: u64,
call_requests: [CallRequest; N],
num_call_requests: u64
) {
if first_revertible_call_request_index != 0 {
let last_non_revertible_call_request_index = first_revertible_call_request_index - 1;
let call_request = call_requests[last_non_revertible_call_request_index];
assert(
min_revertible_side_effect_counter > call_request.end_side_effect_counter, "min_revertible_side_effect_counter must be greater than the end counter of the last non revertible call"
);
}
if first_revertible_call_request_index != num_call_requests {
let call_request = call_requests[first_revertible_call_request_index];
assert(
min_revertible_side_effect_counter <= call_request.start_side_effect_counter, "min_revertible_side_effect_counter must be less than or equal to the start counter of the first revertible call"
);
}
}

Expand Down Expand Up @@ -86,6 +148,34 @@ impl PrivateCallDataValidator {
self.validate_private_call_requests();
self.validate_public_call_requests();
self.validate_teardown_call_request();
self.validate_counters();
}

pub fn validate_as_first_call(
self,
first_revertible_private_call_request_index: u64,
first_revertible_public_call_request_index: u64
) {
let public_inputs = self.data.call_stack_item.public_inputs;
let call_context = public_inputs.call_context;
assert(call_context.is_delegate_call == false, "Users cannot make a delegatecall");
assert(call_context.is_static_call == false, "Users cannot make a static call");

let min_revertible_side_effect_counter = public_inputs.min_revertible_side_effect_counter;
// No need to check that the min_revertible_side_effect_counter falls in the counter range of the private call.
// It is valid as long as it does not fall in the middle of any nested call.
validate_split_call_requests(
min_revertible_side_effect_counter,
first_revertible_private_call_request_index,
self.data.private_call_stack,
self.array_lengths.private_call_stack_hashes
);
validate_split_call_requests(
min_revertible_side_effect_counter,
first_revertible_public_call_request_index,
self.data.public_call_stack,
self.array_lengths.public_call_stack_hashes
);
}

// Confirm that the TxRequest (user's intent) matches the private call being executed.
Expand All @@ -103,11 +193,6 @@ impl PrivateCallDataValidator {
assert_eq(
tx_request.tx_context, call_stack_item.public_inputs.tx_context, "tx_context in tx_request must match tx_context in call_stack_item"
);

// If checking against TxRequest, it must be the first call, which has the following restrictions.
let call_context = call_stack_item.public_inputs.call_context;
assert(call_context.is_delegate_call == false, "Users cannot make a delegatecall");
assert(call_context.is_static_call == false, "Users cannot make a static call");
}

pub fn validate_against_call_request(self, request: CallRequest) {
Expand Down Expand Up @@ -205,19 +290,19 @@ impl PrivateCallDataValidator {
}

fn validate_private_call_requests(self) {
validate_call_requests(
self.data.private_call_stack,
self.data.call_stack_item.public_inputs.private_call_stack_hashes,
self.data.call_stack_item
);
let call_requests = self.data.private_call_stack;
let hashes = self.data.call_stack_item.public_inputs.private_call_stack_hashes;
for i in 0..call_requests.len() {
validate_call_request(call_requests[i], hashes[i], self.data.call_stack_item);
}
}

fn validate_public_call_requests(self) {
validate_call_requests(
self.data.public_call_stack,
self.data.call_stack_item.public_inputs.public_call_stack_hashes,
self.data.call_stack_item
);
let call_requests = self.data.public_call_stack;
let hashes = self.data.call_stack_item.public_inputs.public_call_stack_hashes;
for i in 0..call_requests.len() {
validate_call_request(call_requests[i], hashes[i], self.data.call_stack_item);
}
}

fn validate_teardown_call_request(self) {
Expand All @@ -227,4 +312,79 @@ impl PrivateCallDataValidator {
self.data.call_stack_item
);
}

fn validate_counters(self) {
let public_inputs = self.data.call_stack_item.public_inputs;
let counter_start = public_inputs.start_side_effect_counter;
let counter_end = public_inputs.end_side_effect_counter;

assert(counter_start < counter_end, "private call has incorrect counter range");

validate_incrementing_counters_within_range(
counter_start,
counter_end,
public_inputs.note_hash_read_requests,
self.array_lengths.note_hash_read_requests
);
validate_incrementing_counters_within_range(
counter_start,
counter_end,
public_inputs.nullifier_read_requests,
self.array_lengths.nullifier_read_requests
);
validate_incrementing_counters_within_range(
counter_start,
counter_end,
public_inputs.new_note_hashes,
self.array_lengths.new_note_hashes
);
validate_incrementing_counters_within_range(
counter_start,
counter_end,
public_inputs.new_nullifiers,
self.array_lengths.new_nullifiers
);
validate_incrementing_counters_within_range(
counter_start,
counter_end,
public_inputs.new_l2_to_l1_msgs,
self.array_lengths.new_l2_to_l1_msgs
);
validate_incrementing_counters_within_range(
counter_start,
counter_end,
public_inputs.encrypted_logs_hashes,
self.array_lengths.encrypted_logs_hashes
);
validate_incrementing_counters_within_range(
counter_start,
counter_end,
public_inputs.unencrypted_logs_hashes,
self.array_lengths.unencrypted_logs_hashes
);
validate_incrementing_counter_ranges_within_range(
counter_start,
counter_end,
self.data.private_call_stack,
self.array_lengths.private_call_stack_hashes
);
validate_incrementing_counter_ranges_within_range(
counter_start,
counter_end,
self.data.public_call_stack,
self.array_lengths.public_call_stack_hashes
);

let teardown_call_request_count = if self.data.public_teardown_call_request.hash == 0 {
0
} else {
1
};
validate_incrementing_counter_ranges_within_range(
counter_start,
counter_end,
[self.data.public_teardown_call_request],
teardown_call_request_count
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use dep::types::{

struct PrivateKernelInitHints {
note_hash_nullifier_counters: [u32; MAX_NEW_NOTE_HASHES_PER_CALL],
first_revertible_private_call_request_index: u64,
first_revertible_public_call_request_index: u64
}

// Initialization struct for private inputs to the private kernel
Expand All @@ -29,6 +31,10 @@ impl PrivateKernelInitCircuitPrivateInputs {

let privateCallDataValidator = PrivateCallDataValidator::new(self.private_call);
privateCallDataValidator.validate();
privateCallDataValidator.validate_as_first_call(
self.hints.first_revertible_private_call_request_index,
self.hints.first_revertible_public_call_request_index
);
privateCallDataValidator.validate_against_tx_request(self.tx_request);

let private_call_public_inputs = self.private_call.call_stack_item.public_inputs;
Expand Down Expand Up @@ -67,7 +73,11 @@ mod tests {
pub fn new() -> Self {
let private_call = PrivateCallDataBuilder::new();
let tx_request = private_call.build_tx_request();
let hints = PrivateKernelInitHints { note_hash_nullifier_counters: [0; MAX_NEW_NOTE_HASHES_PER_CALL] };
let hints = PrivateKernelInitHints {
note_hash_nullifier_counters: [0; MAX_NEW_NOTE_HASHES_PER_CALL],
first_revertible_private_call_request_index: 0,
first_revertible_public_call_request_index: 0
};

PrivateKernelInitInputsBuilder { tx_request, private_call, hints }
}
Expand All @@ -92,13 +102,13 @@ mod tests {
let encrypted_log_preimages_length = [100, 75];
let unencrypted_logs_hashes = [26, 46];
let unencrypted_log_preimages_length = [50, 25];
builder.private_call.set_encrypted_logs(encrypted_logs_hashes[0], encrypted_log_preimages_length[0]);
builder.private_call.set_unencrypted_logs(
builder.private_call.public_inputs.add_encrypted_log(encrypted_logs_hashes[0], encrypted_log_preimages_length[0]);
builder.private_call.public_inputs.add_unencrypted_log(
unencrypted_logs_hashes[0],
unencrypted_log_preimages_length[0]
);
builder.private_call.set_encrypted_logs(encrypted_logs_hashes[1], encrypted_log_preimages_length[1]);
builder.private_call.set_unencrypted_logs(
builder.private_call.public_inputs.add_encrypted_log(encrypted_logs_hashes[1], encrypted_log_preimages_length[1]);
builder.private_call.public_inputs.add_unencrypted_log(
unencrypted_logs_hashes[1],
unencrypted_log_preimages_length[1]
);
Expand Down Expand Up @@ -133,7 +143,7 @@ mod tests {
#[test]
fn propagate_max_block_number_request() {
let mut builder = PrivateKernelInitInputsBuilder::new();
builder.private_call.set_tx_max_block_number(42);
builder.private_call.public_inputs.set_tx_max_block_number(42);
let public_inputs = builder.execute();

assert_eq(public_inputs.validation_requests.for_rollup.max_block_number.unwrap(), 42);
Expand All @@ -144,22 +154,20 @@ mod tests {
let mut builder = PrivateKernelInitInputsBuilder::new();
let storage_contract_address = builder.private_call.public_inputs.call_context.storage_contract_address;

let request_0 = ReadRequest { value: 123, counter: 4567 };
builder.private_call.public_inputs.note_hash_read_requests.push(request_0);
let request_1 = ReadRequest { value: 777888, counter: 90 };
builder.private_call.public_inputs.note_hash_read_requests.push(request_1);
builder.private_call.public_inputs.append_note_hash_read_requests(2);
let new_read_requests = builder.private_call.public_inputs.note_hash_read_requests.storage;

let public_inputs = builder.execute();

let end_note_hash_read_requests = public_inputs.validation_requests.note_hash_read_requests;
assert_eq(array_length(end_note_hash_read_requests), 2);

let request = end_note_hash_read_requests[0];
assert_eq(request.read_request, request_0);
assert_eq(request.read_request, new_read_requests[0]);
assert_eq(request.contract_address, storage_contract_address);

let request = end_note_hash_read_requests[1];
assert_eq(request.read_request, request_1);
assert_eq(request.read_request, new_read_requests[1]);
assert_eq(request.contract_address, storage_contract_address);
}

Expand All @@ -168,22 +176,20 @@ mod tests {
let mut builder = PrivateKernelInitInputsBuilder::new();
let storage_contract_address = builder.private_call.public_inputs.call_context.storage_contract_address;

let request_0 = ReadRequest { value: 123, counter: 4567 };
builder.private_call.public_inputs.nullifier_read_requests.push(request_0);
let request_1 = ReadRequest { value: 777888, counter: 90 };
builder.private_call.public_inputs.nullifier_read_requests.push(request_1);
builder.private_call.public_inputs.append_nullifier_read_requests(2);
let requests = builder.private_call.public_inputs.nullifier_read_requests.storage;

let public_inputs = builder.execute();

let end_nullifier_read_requests = public_inputs.validation_requests.nullifier_read_requests;
assert_eq(array_length(end_nullifier_read_requests), 2);

let request = end_nullifier_read_requests[0];
assert_eq(request.read_request, request_0);
assert_eq(request.read_request, requests[0]);
assert_eq(request.contract_address, storage_contract_address);

let request = end_nullifier_read_requests[1];
assert_eq(request.read_request, request_1);
assert_eq(request.read_request, requests[1]);
assert_eq(request.contract_address, storage_contract_address);
}

Expand Down
Loading
Loading