Skip to content

Commit

Permalink
feat: GETCONTRACTINSTANCE and bytecode retrieval perform nullifier me…
Browse files Browse the repository at this point in the history
…mbership checks
  • Loading branch information
dbanks12 committed Dec 6, 2024
1 parent bd64cc5 commit bd963b4
Show file tree
Hide file tree
Showing 17 changed files with 365 additions and 128 deletions.
4 changes: 3 additions & 1 deletion barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,9 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
{
AvmError error = AvmError::NO_ERROR;
// Find the bytecode based on contract address of the public call request
std::vector<uint8_t> bytecode = trace_builder.get_bytecode(public_call_request.contract_address);
// TODO(dbanks12): accept check_membership flag as arg
std::vector<uint8_t> bytecode =
trace_builder.get_bytecode(public_call_request.contract_address, /*check_membership=*/true);

// Set this also on nested call

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ struct ContractInstanceHint {
FF contract_class_id{};
FF initialisation_hash{};
PublicKeysHint public_keys;
NullifierReadTreeHint membership_hint;
};

inline void read(uint8_t const*& it, PublicKeysHint& hint)
Expand All @@ -189,6 +190,7 @@ inline void read(uint8_t const*& it, ContractInstanceHint& hint)
read(it, hint.contract_class_id);
read(it, hint.initialisation_hash);
read(it, hint.public_keys);
read(it, hint.membership_hint);
}

struct AvmContractBytecode {
Expand All @@ -201,7 +203,7 @@ struct AvmContractBytecode {
ContractInstanceHint contract_instance,
ContractClassIdHint contract_class_id_preimage)
: bytecode(std::move(bytecode))
, contract_instance(contract_instance)
, contract_instance(std::move(contract_instance))
, contract_class_id_preimage(contract_class_id_preimage)
{}
AvmContractBytecode(std::vector<uint8_t> bytecode)
Expand Down
75 changes: 45 additions & 30 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,25 +147,24 @@ void AvmTraceBuilder::rollback_to_non_revertible_checkpoint()

std::vector<uint8_t> AvmTraceBuilder::get_bytecode(const FF contract_address, bool check_membership)
{
// uint32_t clk = 0;
// auto clk = static_cast<uint32_t>(main_trace.size()) + 1;
auto clk = static_cast<uint32_t>(main_trace.size()) + 1;

// Find the bytecode based on contract address of the public call request
const AvmContractBytecode bytecode_hint =
*std::ranges::find_if(execution_hints.all_contract_bytecode, [contract_address](const auto& contract) {
return contract.contract_instance.address == contract_address;
});
if (check_membership) {
// NullifierReadTreeHint nullifier_read_hint = bytecode_hint.contract_instance.membership_hint;
//// hinted nullifier should match the specified contract address
// ASSERT(nullifier_read_hint.low_leaf_preimage.nullifier == contract_address);
// bool is_member = merkle_tree_trace_builder.perform_nullifier_read(clk,
// nullifier_read_hint.low_leaf_preimage,
// nullifier_read_hint.low_leaf_index,
// nullifier_read_hint.low_leaf_sibling_path);
//// TODO(dbanks12): handle non-existent bytecode
//// if the contract address nullifier is hinted as "exists", the membership check should agree
// ASSERT(is_member);
NullifierReadTreeHint nullifier_read_hint = bytecode_hint.contract_instance.membership_hint;
// hinted nullifier should match the specified contract address
ASSERT(nullifier_read_hint.low_leaf_preimage.nullifier == contract_address);
bool is_member = merkle_tree_trace_builder.perform_nullifier_read(clk,
nullifier_read_hint.low_leaf_preimage,
nullifier_read_hint.low_leaf_index,
nullifier_read_hint.low_leaf_sibling_path);
// TODO(dbanks12): handle non-existent bytecode
// if the contract address nullifier is hinted as "exists", the membership check should agree
ASSERT(is_member);
}

vinfo("Found bytecode for contract address: ", contract_address);
Expand Down Expand Up @@ -3197,23 +3196,39 @@ AvmError AvmTraceBuilder::op_get_contract_instance(
error = AvmError::CHECK_TAG_ERROR;
}

// Read the contract instance
ContractInstanceHint instance = execution_hints.contract_instance_hints.at(read_address.val);

FF member_value;
switch (chosen_member) {
case ContractInstanceMember::DEPLOYER:
member_value = instance.deployer_addr;
break;
case ContractInstanceMember::CLASS_ID:
member_value = instance.contract_class_id;
break;
case ContractInstanceMember::INIT_HASH:
member_value = instance.initialisation_hash;
break;
default:
member_value = 0;
break;
FF member_value = 0;
bool exists = false;

if (is_ok(error)) {
// Read the contract instance
ContractInstanceHint instance = execution_hints.contract_instance_hints.at(read_address.val);
// nullifier read hint for the contract address
NullifierReadTreeHint nullifier_read_hint = instance.membership_hint;
// hinted nullifier should match the specified contract addrss
exists = nullifier_read_hint.low_leaf_preimage.nullifier == read_address.val;
ASSERT(exists);
bool is_member = merkle_tree_trace_builder.perform_nullifier_read(clk,
nullifier_read_hint.low_leaf_preimage,
nullifier_read_hint.low_leaf_index,
nullifier_read_hint.low_leaf_sibling_path);
// if the contract address nullifier is hinted as "exists", the membership check should agree
ASSERT(is_member == exists);
exists = instance.exists;

switch (chosen_member) {
case ContractInstanceMember::DEPLOYER:
member_value = instance.deployer_addr;
break;
case ContractInstanceMember::CLASS_ID:
member_value = instance.contract_class_id;
break;
case ContractInstanceMember::INIT_HASH:
member_value = instance.initialisation_hash;
break;
default:
member_value = 0;
break;
}
}

// TODO(8603): once instructions can have multiple different tags for writes, write dst as FF and exists as
Expand Down Expand Up @@ -3257,7 +3272,7 @@ AvmError AvmTraceBuilder::op_get_contract_instance(
// TODO(8603): once instructions can have multiple different tags for writes, remove this and do a
// constrained writes
write_to_memory(resolved_dst_offset, member_value, AvmMemoryTag::FF);
write_to_memory(resolved_exists_offset, FF(static_cast<uint32_t>(instance.exists)), AvmMemoryTag::U1);
write_to_memory(resolved_exists_offset, FF(static_cast<uint32_t>(exists)), AvmMemoryTag::U1);

// TODO(dbanks12): compute contract address nullifier from instance preimage and perform membership check

Expand Down
12 changes: 10 additions & 2 deletions yarn-project/circuits.js/src/structs/avm/avm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ export class AvmContractInstanceHint {
public readonly contractClassId: Fr,
public readonly initializationHash: Fr,
public readonly publicKeys: PublicKeys,
public readonly membershipHint: AvmNullifierReadTreeHint = AvmNullifierReadTreeHint.empty(),
) {}
/**
* Serializes the inputs to a buffer.
Expand Down Expand Up @@ -288,7 +289,8 @@ export class AvmContractInstanceHint {
this.deployer.isZero() &&
this.contractClassId.isZero() &&
this.initializationHash.isZero() &&
this.publicKeys.isEmpty()
this.publicKeys.isEmpty() &&
this.membershipHint.isEmpty()
);
}

Expand All @@ -315,6 +317,7 @@ export class AvmContractInstanceHint {
fields.contractClassId,
fields.initializationHash,
fields.publicKeys,
fields.membershipHint,
] as const;
}

Expand All @@ -333,6 +336,7 @@ export class AvmContractInstanceHint {
Fr.fromBuffer(reader),
Fr.fromBuffer(reader),
PublicKeys.fromBuffer(reader),
AvmNullifierReadTreeHint.fromBuffer(reader),
);
}

Expand Down Expand Up @@ -592,7 +596,7 @@ export class AvmNullifierReadTreeHint {
constructor(
public readonly lowLeafPreimage: NullifierLeafPreimage,
public readonly lowLeafIndex: Fr,
public readonly _lowLeafSiblingPath: Fr[],
public _lowLeafSiblingPath: Fr[],
) {
this.lowLeafSiblingPath = new Vector(_lowLeafSiblingPath);
}
Expand Down Expand Up @@ -630,6 +634,10 @@ export class AvmNullifierReadTreeHint {
return new AvmNullifierReadTreeHint(fields.lowLeafPreimage, fields.lowLeafIndex, fields.lowLeafSiblingPath.items);
}

static empty(): AvmNullifierReadTreeHint {
return new AvmNullifierReadTreeHint(NullifierLeafPreimage.empty(), Fr.ZERO, []);
}

/**
* Extracts fields from an instance.
* @param fields - Fields to create the instance from.
Expand Down
2 changes: 2 additions & 0 deletions yarn-project/circuits.js/src/tests/factories.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,7 @@ export function makeAvmBytecodeHints(seed = 0): AvmContractBytecodeHints {
instance.contractClassId,
instance.initializationHash,
instance.publicKeys,
makeAvmNullifierReadTreeHints(seed + 0x2000),
);

const publicBytecodeCommitment = computePublicBytecodeCommitment(packedBytecode);
Expand Down Expand Up @@ -1366,6 +1367,7 @@ export function makeAvmContractInstanceHint(seed = 0): AvmContractInstanceHint {
new Point(new Fr(seed + 0x10), new Fr(seed + 0x11), false),
new Point(new Fr(seed + 0x12), new Fr(seed + 0x13), false),
),
makeAvmNullifierReadTreeHints(seed + 0x1000),
);
}

Expand Down
74 changes: 28 additions & 46 deletions yarn-project/simulator/src/avm/avm_simulator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import {
GasFees,
GlobalVariables,
PublicDataTreeLeafPreimage,
type PublicFunction,
PublicKeys,
SerializableContractInstance,
} from '@aztec/circuits.js';
Expand All @@ -23,7 +22,8 @@ import { randomInt } from 'crypto';
import { mock } from 'jest-mock-extended';

import { PublicEnqueuedCallSideEffectTrace } from '../public/enqueued_call_side_effect_trace.js';
import { type WorldStateDB } from '../public/public_db_sources.js';
import { MockedAvmTestContractDataSource } from '../public/fixtures/index.js';
import { WorldStateDB } from '../public/public_db_sources.js';
import { type PublicSideEffectTraceInterface } from '../public/side_effect_trace_interface.js';
import { type AvmContext } from './avm_context.js';
import { type AvmExecutionEnvironment } from './avm_execution_environment.js';
Expand Down Expand Up @@ -127,46 +127,19 @@ describe('AVM simulator: transpiled Noir contracts', () => {
const globals = GlobalVariables.empty();
globals.timestamp = TIMESTAMP;

const bytecode = getAvmTestContractBytecode('public_dispatch');
const fnSelector = getAvmTestContractFunctionSelector('public_dispatch');
const publicFn: PublicFunction = { bytecode, selector: fnSelector };
const contractClass = makeContractClassPublic(0, publicFn);
const contractInstance = makeContractInstanceFromClassId(contractClass.id);

// The values here should match those in getContractInstance test case
const instanceGet = new SerializableContractInstance({
version: 1,
salt: new Fr(0x123),
deployer: AztecAddress.fromNumber(0x456),
contractClassId: new Fr(0x789),
initializationHash: new Fr(0x101112),
publicKeys: new PublicKeys(
new Point(new Fr(0x131415), new Fr(0x161718), false),
new Point(new Fr(0x192021), new Fr(0x222324), false),
new Point(new Fr(0x252627), new Fr(0x282930), false),
new Point(new Fr(0x313233), new Fr(0x343536), false),
),
}).withAddress(contractInstance.address);
const worldStateDB = mock<WorldStateDB>();
const tmp = openTmpStore();
const telemetryClient = new NoopTelemetryClient();
const merkleTree = await (await MerkleTrees.new(tmp, telemetryClient)).fork();
worldStateDB.getMerkleInterface.mockReturnValue(merkleTree);

worldStateDB.getContractInstance
.mockResolvedValueOnce(contractInstance)
.mockResolvedValueOnce(instanceGet) // test gets deployer
.mockResolvedValueOnce(instanceGet) // test gets class id
.mockResolvedValueOnce(instanceGet) // test gets init hash
.mockResolvedValue(contractInstance);
worldStateDB.getContractClass.mockResolvedValue(contractClass);

const storageValue = new Fr(5);
mockStorageRead(worldStateDB, storageValue);
const telemetry = new NoopTelemetryClient();
const merkleTrees = await (await MerkleTrees.new(openTmpStore(), telemetry)).fork();
const contractDataSource = new MockedAvmTestContractDataSource();
const worldStateDB = new WorldStateDB(merkleTrees, contractDataSource);

const contractInstance = contractDataSource.contractInstance;
await merkleTrees.batchInsert(MerkleTreeId.NULLIFIER_TREE, [contractInstance.address.toBuffer()], 0);

const trace = mock<PublicSideEffectTraceInterface>();
const merkleTrees = await AvmEphemeralForest.create(worldStateDB.getMerkleInterface());
const persistableState = initPersistableStateManager({ worldStateDB, trace, merkleTrees });
const nestedTrace = mock<PublicSideEffectTraceInterface>();
mockTraceFork(trace, nestedTrace);
const ephemeralTrees = await AvmEphemeralForest.create(worldStateDB.getMerkleInterface());
const persistableState = initPersistableStateManager({ worldStateDB, trace, merkleTrees: ephemeralTrees });
const environment = initExecutionEnvironment({
functionSelector,
calldata,
Expand All @@ -176,10 +149,6 @@ describe('AVM simulator: transpiled Noir contracts', () => {
});
const context = initContext({ env: environment, persistableState });

const nestedTrace = mock<PublicSideEffectTraceInterface>();
mockTraceFork(trace, nestedTrace);
mockGetBytecode(worldStateDB, bytecode);

// First we simulate (though it's not needed in this simple case).
const simulator = new AvmSimulator(context);
const results = await simulator.execute();
Expand Down Expand Up @@ -591,7 +560,7 @@ describe('AVM simulator: transpiled Noir contracts', () => {
const bytecode = getAvmTestContractBytecode('nullifier_exists');

if (exists) {
mockNullifierExists(worldStateDB, leafIndex, value0);
mockNullifierExists(worldStateDB, leafIndex, siloedNullifier0);
}

const results = await new AvmSimulator(context).executeBytecode(bytecode);
Expand Down Expand Up @@ -883,7 +852,14 @@ describe('AVM simulator: transpiled Noir contracts', () => {
new Point(new Fr(0x313233), new Fr(0x343536), false),
),
});
mockGetContractInstance(worldStateDB, contractInstance.withAddress(address));
const contractInstanceWithAddress = contractInstance.withAddress(address);
// mock once per enum value (deployer, classId, initializationHash)
mockGetContractInstance(worldStateDB, contractInstanceWithAddress);
mockGetContractInstance(worldStateDB, contractInstanceWithAddress);
mockGetContractInstance(worldStateDB, contractInstanceWithAddress);
mockNullifierExists(worldStateDB, contractInstanceWithAddress.address.toField());
mockNullifierExists(worldStateDB, contractInstanceWithAddress.address.toField());
mockNullifierExists(worldStateDB, contractInstanceWithAddress.address.toField());

const bytecode = getAvmTestContractBytecode('test_get_contract_instance');

Expand Down Expand Up @@ -952,6 +928,7 @@ describe('AVM simulator: transpiled Noir contracts', () => {
mockGetContractClass(worldStateDB, contractClass);
const contractInstance = makeContractInstanceFromClassId(contractClass.id);
mockGetContractInstance(worldStateDB, contractInstance);
mockNullifierExists(worldStateDB, contractInstance.address.toField());

const nestedTrace = mock<PublicSideEffectTraceInterface>();
mockTraceFork(trace, nestedTrace);
Expand All @@ -977,6 +954,7 @@ describe('AVM simulator: transpiled Noir contracts', () => {
mockGetContractClass(worldStateDB, contractClass);
const contractInstance = makeContractInstanceFromClassId(contractClass.id);
mockGetContractInstance(worldStateDB, contractInstance);
mockNullifierExists(worldStateDB, contractInstance.address.toField());

const nestedTrace = mock<PublicSideEffectTraceInterface>();
mockTraceFork(trace, nestedTrace);
Expand Down Expand Up @@ -1005,6 +983,7 @@ describe('AVM simulator: transpiled Noir contracts', () => {
mockGetContractClass(worldStateDB, contractClass);
const contractInstance = makeContractInstanceFromClassId(contractClass.id);
mockGetContractInstance(worldStateDB, contractInstance);
mockNullifierExists(worldStateDB, contractInstance.address.toField());

mockTraceFork(trace);

Expand All @@ -1029,6 +1008,7 @@ describe('AVM simulator: transpiled Noir contracts', () => {
mockGetContractClass(worldStateDB, contractClass);
const contractInstance = makeContractInstanceFromClassId(contractClass.id);
mockGetContractInstance(worldStateDB, contractInstance);
mockNullifierExists(worldStateDB, contractInstance.address.toField());

const nestedTrace = mock<PublicSideEffectTraceInterface>();
mockTraceFork(trace, nestedTrace);
Expand Down Expand Up @@ -1060,6 +1040,7 @@ describe('AVM simulator: transpiled Noir contracts', () => {
mockGetContractClass(worldStateDB, contractClass);
const contractInstance = makeContractInstanceFromClassId(contractClass.id);
mockGetContractInstance(worldStateDB, contractInstance);
mockNullifierExists(worldStateDB, contractInstance.address.toField());

mockTraceFork(trace);

Expand All @@ -1084,6 +1065,7 @@ describe('AVM simulator: transpiled Noir contracts', () => {
mockGetContractClass(worldStateDB, contractClass);
const contractInstance = makeContractInstanceFromClassId(contractClass.id);
mockGetContractInstance(worldStateDB, contractInstance);
mockNullifierExists(worldStateDB, contractInstance.address.toField());

mockTraceFork(trace);

Expand Down
Loading

0 comments on commit bd963b4

Please sign in to comment.