From e1f3e320f15003282ca5b5ea707471cfcd1b6354 Mon Sep 17 00:00:00 2001 From: Rahul Kothari Date: Mon, 29 Apr 2024 22:40:56 +0400 Subject: [PATCH 01/42] chore(docs): fix migration notes (#6083) --- docs/docs/misc/migration_notes.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/docs/misc/migration_notes.md b/docs/docs/misc/migration_notes.md index 9848ff58b557..6a93becce8c1 100644 --- a/docs/docs/misc/migration_notes.md +++ b/docs/docs/misc/migration_notes.md @@ -8,13 +8,13 @@ Aztec is in full-speed development. Literally every version breaks compatibility ## 0.36.0 -## `FieldNote` removed +### `FieldNote` removed `FieldNote` only existed for testing purposes, and was not a note type that should be used in any real application. Its name unfortunately led users to think that it was a note type suitable to store a `Field` value, which it wasn't. If using `FieldNote`, you most likely want to use `ValueNote` instead, which has both randomness for privacy and an owner for proper nullification. -## `SlowUpdatesTree` replaced for `SharedMutable` +### `SlowUpdatesTree` replaced for `SharedMutable` The old `SlowUpdatesTree` contract and libraries have been removed from the codebase, use the new `SharedMutable` library instead. This will require that you add a global variable specifying a delay in blocks for updates, and replace the slow updates tree state variable with `SharedMutable` variables. @@ -36,7 +36,7 @@ Reading from `SharedMutable` is much simpler, all that's required is to call `ge Finally, you can remove all capsule usage on the client code or tests, since those are no longer required when working with `SharedMutable`. -## [Aztec.nr & js] Portal addresses +### [Aztec.nr & js] Portal addresses Deployments have been modified. No longer are portal addresses treated as a special class, being immutably set on creation of a contract. They are no longer passed in differently compared to the other variables and instead should be implemented using usual storage by those who require it. One should use the storage that matches the usecase - likely shared storage to support private and public. From 6065a6c4157a2d356964f4c5476425da55e09728 Mon Sep 17 00:00:00 2001 From: Lasse Herskind <16536249+LHerskind@users.noreply.github.com> Date: Mon, 29 Apr 2024 23:14:09 +0100 Subject: [PATCH 02/42] test: refactor public cross chain tests for speed (#6082) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixing #6055, time goes from 400 seconds -> 150 seconds. Minor issues when dealing with the snapshots across instances as the cross chain harness is taking the pxe and node as args for its constructor so if going directly from snapshot could run into issue where it is connected to a different instance or the like. Also ran into a deserialization issue for the aztec address, unclear why as it seems like it have been registered 🤷 --- .../token_portal/typescript_glue_code.md | 2 +- .../blacklist_token_contract_test.ts | 2 - .../e2e_public_cross_chain_messaging.test.ts | 412 ------------------ .../deposits.test.ts | 161 +++++++ .../failure_cases.test.ts | 68 +++ .../l1_to_l2.test.ts | 122 ++++++ .../l2_to_l1.test.ts | 116 +++++ ...lic_cross_chain_messaging_contract_test.ts | 225 ++++++++++ 8 files changed, 693 insertions(+), 415 deletions(-) delete mode 100644 yarn-project/end-to-end/src/e2e_public_cross_chain_messaging.test.ts create mode 100644 yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/deposits.test.ts create mode 100644 yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/failure_cases.test.ts create mode 100644 yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/l1_to_l2.test.ts create mode 100644 yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/l2_to_l1.test.ts create mode 100644 yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/public_cross_chain_messaging_contract_test.ts diff --git a/docs/docs/developers/tutorials/token_portal/typescript_glue_code.md b/docs/docs/developers/tutorials/token_portal/typescript_glue_code.md index a81c58e1a590..592ca263f2f7 100644 --- a/docs/docs/developers/tutorials/token_portal/typescript_glue_code.md +++ b/docs/docs/developers/tutorials/token_portal/typescript_glue_code.md @@ -109,7 +109,7 @@ This fetches the wallets from the sandbox and deploys our cross chain harness on ## Public flow test -#include_code e2e_public_cross_chain /yarn-project/end-to-end/src/e2e_public_cross_chain_messaging.test.ts typescript +#include_code e2e_public_cross_chain /yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/deposits.test.ts typescript ## Running the test diff --git a/yarn-project/end-to-end/src/e2e_blacklist_token_contract/blacklist_token_contract_test.ts b/yarn-project/end-to-end/src/e2e_blacklist_token_contract/blacklist_token_contract_test.ts index c0872969c69c..afbb66309f82 100644 --- a/yarn-project/end-to-end/src/e2e_blacklist_token_contract/blacklist_token_contract_test.ts +++ b/yarn-project/end-to-end/src/e2e_blacklist_token_contract/blacklist_token_contract_test.ts @@ -89,10 +89,8 @@ export class BlacklistTokenContractTest { this.admin = this.wallets[0]; this.other = this.wallets[1]; this.blacklisted = this.wallets[2]; - // this.accounts = this.wallets.map(a => a.getCompleteAddress()); this.accounts = await pxe.getRegisteredAccounts(); this.wallets.forEach((w, i) => this.logger.verbose(`Wallet ${i} address: ${w.getAddress()}`)); - this.accounts.forEach((w, i) => this.logger.verbose(`Account ${i} address: ${w.address}`)); }); await this.snapshotManager.snapshot( diff --git a/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging.test.ts b/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging.test.ts deleted file mode 100644 index 3edfd40cc7b0..000000000000 --- a/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging.test.ts +++ /dev/null @@ -1,412 +0,0 @@ -import { - type AccountWallet, - type AztecAddress, - type AztecNode, - type DebugLogger, - type DeployL1Contracts, - EthAddress, - type EthAddressLike, - type FieldLike, - Fr, - L1Actor, - L1ToL2Message, - L2Actor, - type PXE, - computeAuthWitMessageHash, - computeSecretHash, -} from '@aztec/aztec.js'; -import { sha256ToField } from '@aztec/foundation/crypto'; -import { InboxAbi, OutboxAbi } from '@aztec/l1-artifacts'; -import { TestContract } from '@aztec/noir-contracts.js'; -import { type TokenContract } from '@aztec/noir-contracts.js/Token'; -import { type TokenBridgeContract } from '@aztec/noir-contracts.js/TokenBridge'; - -import { type Chain, type GetContractReturnType, type Hex, type HttpTransport, type PublicClient } from 'viem'; -import { decodeEventLog, toFunctionSelector } from 'viem/utils'; - -import { publicDeployAccounts, setup } from './fixtures/utils.js'; -import { CrossChainTestHarness } from './shared/cross_chain_test_harness.js'; - -describe('e2e_public_cross_chain_messaging', () => { - let aztecNode: AztecNode; - let pxe: PXE; - let deployL1ContractsValues: DeployL1Contracts; - let logger: DebugLogger; - let teardown: () => Promise; - let wallets: AccountWallet[]; - - let user1Wallet: AccountWallet; - let user2Wallet: AccountWallet; - let ethAccount: EthAddress; - let ownerAddress: AztecAddress; - - let crossChainTestHarness: CrossChainTestHarness; - let l2Token: TokenContract; - let l2Bridge: TokenBridgeContract; - let inbox: GetContractReturnType>; - let outbox: GetContractReturnType>; - - beforeAll(async () => { - ({ aztecNode, pxe, deployL1ContractsValues, wallets, logger, teardown } = await setup(2)); - user1Wallet = wallets[0]; - user2Wallet = wallets[1]; - await publicDeployAccounts(wallets[0], wallets.slice(0, 2)); - }, 45_000); - - beforeEach(async () => { - crossChainTestHarness = await CrossChainTestHarness.new( - aztecNode, - pxe, - deployL1ContractsValues.publicClient, - deployL1ContractsValues.walletClient, - wallets[0], - logger, - ); - l2Token = crossChainTestHarness.l2Token; - l2Bridge = crossChainTestHarness.l2Bridge; - ethAccount = crossChainTestHarness.ethAccount; - ownerAddress = crossChainTestHarness.ownerAddress; - inbox = crossChainTestHarness.inbox; - outbox = crossChainTestHarness.outbox; - - logger.info('Successfully deployed contracts and initialized portal'); - }, 100_000); - - afterAll(async () => { - await teardown(); - }); - - // docs:start:e2e_public_cross_chain - it('Publicly deposit funds from L1 -> L2 and withdraw back to L1', async () => { - // Generate a claim secret using pedersen - const l1TokenBalance = 1000000n; - const bridgeAmount = 100n; - - const [secret, secretHash] = crossChainTestHarness.generateClaimSecret(); - - // 1. Mint tokens on L1 - await crossChainTestHarness.mintTokensOnL1(l1TokenBalance); - - // 2. Deposit tokens to the TokenPortal - const msgHash = await crossChainTestHarness.sendTokensToPortalPublic(bridgeAmount, secretHash); - expect(await crossChainTestHarness.getL1BalanceOf(ethAccount)).toBe(l1TokenBalance - bridgeAmount); - - // Wait for the message to be available for consumption - await crossChainTestHarness.makeMessageConsumable(msgHash); - - // Get message leaf index, needed for claiming in public - const maybeIndexAndPath = await aztecNode.getL1ToL2MessageMembershipWitness('latest', msgHash, 0n); - expect(maybeIndexAndPath).toBeDefined(); - const messageLeafIndex = maybeIndexAndPath![0]; - - // 3. Consume L1 -> L2 message and mint public tokens on L2 - await crossChainTestHarness.consumeMessageOnAztecAndMintPublicly(bridgeAmount, secret, messageLeafIndex); - await crossChainTestHarness.expectPublicBalanceOnL2(ownerAddress, bridgeAmount); - const afterBalance = bridgeAmount; - - // time to withdraw the funds again! - logger.info('Withdrawing funds from L2'); - - // 4. Give approval to bridge to burn owner's funds: - const withdrawAmount = 9n; - const nonce = Fr.random(); - const burnMessageHash = computeAuthWitMessageHash( - l2Bridge.address, - wallets[0].getChainId(), - wallets[0].getVersion(), - l2Token.methods.burn_public(ownerAddress, withdrawAmount, nonce).request(), - ); - await user1Wallet.setPublicAuthWit(burnMessageHash, true).send().wait(); - - // 5. Withdraw owner's funds from L2 to L1 - const l2ToL1Message = crossChainTestHarness.getL2ToL1MessageLeaf(withdrawAmount); - const l2TxReceipt = await crossChainTestHarness.withdrawPublicFromAztecToL1(withdrawAmount, nonce); - await crossChainTestHarness.expectPublicBalanceOnL2(ownerAddress, afterBalance - withdrawAmount); - - // Check balance before and after exit. - expect(await crossChainTestHarness.getL1BalanceOf(ethAccount)).toBe(l1TokenBalance - bridgeAmount); - - const [l2ToL1MessageIndex, siblingPath] = await aztecNode.getL2ToL1MessageMembershipWitness( - l2TxReceipt.blockNumber!, - l2ToL1Message, - ); - - await crossChainTestHarness.withdrawFundsFromBridgeOnL1( - withdrawAmount, - l2TxReceipt.blockNumber!, - l2ToL1MessageIndex, - siblingPath, - ); - expect(await crossChainTestHarness.getL1BalanceOf(ethAccount)).toBe(l1TokenBalance - bridgeAmount + withdrawAmount); - }, 120_000); - // docs:end:e2e_public_cross_chain - - // Unit tests for TokenBridge's public methods. - - it('Someone else can mint funds to me on my behalf (publicly)', async () => { - // Generate a claim secret using pedersen - const l1TokenBalance = 1000000n; - const bridgeAmount = 100n; - - const [secret, secretHash] = crossChainTestHarness.generateClaimSecret(); - - await crossChainTestHarness.mintTokensOnL1(l1TokenBalance); - const msgHash = await crossChainTestHarness.sendTokensToPortalPublic(bridgeAmount, secretHash); - expect(await crossChainTestHarness.getL1BalanceOf(ethAccount)).toBe(l1TokenBalance - bridgeAmount); - - await crossChainTestHarness.makeMessageConsumable(msgHash); - - const content = sha256ToField([ - Buffer.from(toFunctionSelector('mint_public(bytes32,uint256)').substring(2), 'hex'), - user2Wallet.getAddress(), - new Fr(bridgeAmount), - ]); - const wrongMessage = new L1ToL2Message( - new L1Actor(crossChainTestHarness.tokenPortalAddress, crossChainTestHarness.publicClient.chain.id), - new L2Actor(l2Bridge.address, 1), - content, - secretHash, - ); - - // get message leaf index, needed for claiming in public - const maybeIndexAndPath = await aztecNode.getL1ToL2MessageMembershipWitness('latest', msgHash, 0n); - expect(maybeIndexAndPath).toBeDefined(); - const messageLeafIndex = maybeIndexAndPath![0]; - - // user2 tries to consume this message and minting to itself -> should fail since the message is intended to be consumed only by owner. - await expect( - l2Bridge - .withWallet(user2Wallet) - .methods.claim_public(user2Wallet.getAddress(), bridgeAmount, secret, messageLeafIndex) - .prove(), - ).rejects.toThrow(`No non-nullified L1 to L2 message found for message hash ${wrongMessage.hash().toString()}`); - - // user2 consumes owner's L1-> L2 message on bridge contract and mints public tokens on L2 - logger.info("user2 consumes owner's message on L2 Publicly"); - await l2Bridge - .withWallet(user2Wallet) - .methods.claim_public(ownerAddress, bridgeAmount, secret, messageLeafIndex) - .send() - .wait(); - // ensure funds are gone to owner and not user2. - await crossChainTestHarness.expectPublicBalanceOnL2(ownerAddress, bridgeAmount); - await crossChainTestHarness.expectPublicBalanceOnL2(user2Wallet.getAddress(), 0n); - }, 90_000); - - it("Bridge can't withdraw my funds if I don't give approval", async () => { - const mintAmountToOwner = 100n; - await crossChainTestHarness.mintTokensPublicOnL2(mintAmountToOwner); - - const withdrawAmount = 9n; - const nonce = Fr.random(); - // Should fail as owner has not given approval to bridge burn their funds. - await expect( - l2Bridge - .withWallet(user1Wallet) - .methods.exit_to_l1_public(ethAccount, withdrawAmount, EthAddress.ZERO, nonce) - .prove(), - ).rejects.toThrow('Assertion failed: Message not authorized by account'); - }, 60_000); - - it("can't claim funds privately which were intended for public deposit from the token portal", async () => { - const bridgeAmount = 100n; - const [secret, secretHash] = crossChainTestHarness.generateClaimSecret(); - - await crossChainTestHarness.mintTokensOnL1(bridgeAmount); - const msgHash = await crossChainTestHarness.sendTokensToPortalPublic(bridgeAmount, secretHash); - expect(await crossChainTestHarness.getL1BalanceOf(ethAccount)).toBe(0n); - - await crossChainTestHarness.makeMessageConsumable(msgHash); - - // Wrong message hash - const content = sha256ToField([ - Buffer.from(toFunctionSelector('mint_private(bytes32,uint256)').substring(2), 'hex'), - secretHash, - new Fr(bridgeAmount), - ]); - const wrongMessage = new L1ToL2Message( - new L1Actor(crossChainTestHarness.tokenPortalAddress, crossChainTestHarness.publicClient.chain.id), - new L2Actor(l2Bridge.address, 1), - content, - secretHash, - ); - - await expect( - l2Bridge.withWallet(user2Wallet).methods.claim_private(secretHash, bridgeAmount, secret).prove(), - ).rejects.toThrow(`No non-nullified L1 to L2 message found for message hash ${wrongMessage.hash().toString()}`); - }, 60_000); - - // Note: We register one portal address when deploying contract but that address is no-longer the only address - // allowed to receive messages from the given contract. In the following test we'll test that it's really the case. - it.each([true, false])( - 'can send an L2 -> L1 message to a non-registered portal address from private or public', - async (isPrivate: boolean) => { - const testContract = await TestContract.deploy(user1Wallet).send().deployed(); - - const content = Fr.random(); - const recipient = crossChainTestHarness.ethAccount; - - let l2TxReceipt; - - // We create the L2 -> L1 message using the test contract - if (isPrivate) { - l2TxReceipt = await testContract.methods - .create_l2_to_l1_message_arbitrary_recipient_private(content, recipient) - .send() - .wait(); - } else { - l2TxReceipt = await testContract.methods - .create_l2_to_l1_message_arbitrary_recipient_public(content, recipient) - .send() - .wait(); - } - - const l2ToL1Message = { - sender: { actor: testContract.address.toString() as Hex, version: 1n }, - recipient: { - actor: recipient.toString() as Hex, - chainId: BigInt(crossChainTestHarness.publicClient.chain.id), - }, - content: content.toString() as Hex, - }; - - const leaf = sha256ToField([ - testContract.address, - new Fr(1), // aztec version - recipient.toBuffer32(), - new Fr(crossChainTestHarness.publicClient.chain.id), // chain id - content, - ]); - - const [l2MessageIndex, siblingPath] = await aztecNode.getL2ToL1MessageMembershipWitness( - l2TxReceipt.blockNumber!, - leaf, - ); - - const txHash = await outbox.write.consume( - [ - l2ToL1Message, - BigInt(l2TxReceipt.blockNumber!), - BigInt(l2MessageIndex), - siblingPath.toBufferArray().map((buf: Buffer) => `0x${buf.toString('hex')}`) as readonly `0x${string}`[], - ], - {} as any, - ); - - const txReceipt = await crossChainTestHarness.publicClient.waitForTransactionReceipt({ - hash: txHash, - }); - - // Exactly 1 event should be emitted in the transaction - expect(txReceipt.logs.length).toBe(1); - - // We decode the event log before checking it - const txLog = txReceipt.logs[0]; - const topics = decodeEventLog({ - abi: OutboxAbi, - data: txLog.data, - topics: txLog.topics, - }) as { - eventName: 'MessageConsumed'; - args: { - l2BlockNumber: bigint; - root: `0x${string}`; - messageHash: `0x${string}`; - leafIndex: bigint; - }; - }; - - // We check that MessageConsumed event was emitted with the expected message hash and leaf index - expect(topics.args.messageHash).toStrictEqual(leaf.toString()); - expect(topics.args.leafIndex).toStrictEqual(BigInt(0)); - }, - 60_000, - ); - - // Note: We register one portal address when deploying contract but that address is no-longer the only address - // allowed to send messages to the given contract. In the following test we'll test that it's really the case. - it.each([true, false])( - 'can send an L1 -> L2 message from a non-registered portal address consumed from private or public and then sends and claims exactly the same message again', - async (isPrivate: boolean) => { - const testContract = await TestContract.deploy(user1Wallet).send().deployed(); - - const consumeMethod = isPrivate - ? (content: FieldLike, secret: FieldLike, sender: EthAddressLike, _leafIndex: FieldLike) => - testContract.methods.consume_message_from_arbitrary_sender_private(content, secret, sender) - : testContract.methods.consume_message_from_arbitrary_sender_public; - - const secret = Fr.random(); - - const message = new L1ToL2Message( - new L1Actor(crossChainTestHarness.ethAccount, crossChainTestHarness.publicClient.chain.id), - new L2Actor(testContract.address, 1), - Fr.random(), // content - computeSecretHash(secret), // secretHash - ); - - await sendL2Message(message); - - const [message1Index, _1] = (await aztecNode.getL1ToL2MessageMembershipWitness('latest', message.hash(), 0n))!; - - // Finally, we consume the L1 -> L2 message using the test contract either from private or public - await consumeMethod(message.content, secret, message.sender.sender, message1Index).send().wait(); - - // We send and consume the exact same message the second time to test that oracles correctly return the new - // non-nullified message - await sendL2Message(message); - - // We check that the duplicate message was correctly inserted by checking that its message index is defined and - // larger than the previous message index - const [message2Index, _2] = (await aztecNode.getL1ToL2MessageMembershipWitness( - 'latest', - message.hash(), - message1Index + 1n, - ))!; - - expect(message2Index).toBeDefined(); - expect(message2Index).toBeGreaterThan(message1Index); - - // Now we consume the message again. Everything should pass because oracle should return the duplicate message - // which is not nullified - await consumeMethod(message.content, secret, message.sender.sender, message2Index).send().wait(); - }, - 120_000, - ); - - const sendL2Message = async (message: L1ToL2Message) => { - // We inject the message to Inbox - const txHash = await inbox.write.sendL2Message( - [ - { actor: message.recipient.recipient.toString() as Hex, version: 1n }, - message.content.toString() as Hex, - message.secretHash.toString() as Hex, - ] as const, - {} as any, - ); - - // We check that the message was correctly injected by checking the emitted event - const msgHash = message.hash(); - { - const txReceipt = await crossChainTestHarness.publicClient.waitForTransactionReceipt({ - hash: txHash, - }); - - // Exactly 1 event should be emitted in the transaction - expect(txReceipt.logs.length).toBe(1); - - // We decode the event and get leaf out of it - const txLog = txReceipt.logs[0]; - const topics = decodeEventLog({ - abi: InboxAbi, - data: txLog.data, - topics: txLog.topics, - }); - const receivedMsgHash = topics.args.hash; - - // We check that the leaf inserted into the subtree matches the expected message hash - expect(receivedMsgHash).toBe(msgHash.toString()); - } - - await crossChainTestHarness.makeMessageConsumable(msgHash); - }; -}); diff --git a/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/deposits.test.ts b/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/deposits.test.ts new file mode 100644 index 000000000000..566118a99cc1 --- /dev/null +++ b/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/deposits.test.ts @@ -0,0 +1,161 @@ +import { Fr, L1Actor, L1ToL2Message, L2Actor, computeAuthWitMessageHash } from '@aztec/aztec.js'; +import { sha256ToField } from '@aztec/foundation/crypto'; + +import { toFunctionSelector } from 'viem'; + +import { PublicCrossChainMessagingContractTest } from './public_cross_chain_messaging_contract_test.js'; + +describe('e2e_public_cross_chain_messaging deposits', () => { + const t = new PublicCrossChainMessagingContractTest('deposits'); + + let { + wallets, + crossChainTestHarness, + ethAccount, + aztecNode, + logger, + ownerAddress, + l2Bridge, + l2Token, + user1Wallet, + user2Wallet, + } = t; + + beforeEach(async () => { + await t.applyBaseSnapshots(); + await t.setup(); + // Have to destructure again to ensure we have latest refs. + ({ wallets, crossChainTestHarness, user1Wallet, user2Wallet } = t); + + ethAccount = crossChainTestHarness.ethAccount; + aztecNode = crossChainTestHarness.aztecNode; + logger = crossChainTestHarness.logger; + ownerAddress = crossChainTestHarness.ownerAddress; + l2Bridge = crossChainTestHarness.l2Bridge; + l2Token = crossChainTestHarness.l2Token; + }, 200_000); + + afterEach(async () => { + await t.teardown(); + }); + + // docs:start:e2e_public_cross_chain + it('Publicly deposit funds from L1 -> L2 and withdraw back to L1', async () => { + // Generate a claim secret using pedersen + const l1TokenBalance = 1000000n; + const bridgeAmount = 100n; + + const [secret, secretHash] = crossChainTestHarness.generateClaimSecret(); + + // 1. Mint tokens on L1 + logger.verbose(`1. Mint tokens on L1`); + await crossChainTestHarness.mintTokensOnL1(l1TokenBalance); + + // 2. Deposit tokens to the TokenPortal + logger.verbose(`2. Deposit tokens to the TokenPortal`); + const msgHash = await crossChainTestHarness.sendTokensToPortalPublic(bridgeAmount, secretHash); + expect(await crossChainTestHarness.getL1BalanceOf(ethAccount)).toBe(l1TokenBalance - bridgeAmount); + + // Wait for the message to be available for consumption + logger.verbose(`Wait for the message to be available for consumption`); + await crossChainTestHarness.makeMessageConsumable(msgHash); + + // Get message leaf index, needed for claiming in public + const maybeIndexAndPath = await aztecNode.getL1ToL2MessageMembershipWitness('latest', msgHash, 0n); + expect(maybeIndexAndPath).toBeDefined(); + const messageLeafIndex = maybeIndexAndPath![0]; + + // 3. Consume L1 -> L2 message and mint public tokens on L2 + logger.verbose('3. Consume L1 -> L2 message and mint public tokens on L2'); + await crossChainTestHarness.consumeMessageOnAztecAndMintPublicly(bridgeAmount, secret, messageLeafIndex); + await crossChainTestHarness.expectPublicBalanceOnL2(ownerAddress, bridgeAmount); + const afterBalance = bridgeAmount; + + // time to withdraw the funds again! + logger.info('Withdrawing funds from L2'); + + // 4. Give approval to bridge to burn owner's funds: + const withdrawAmount = 9n; + const nonce = Fr.random(); + const burnMessageHash = computeAuthWitMessageHash( + l2Bridge.address, + wallets[0].getChainId(), + wallets[0].getVersion(), + l2Token.methods.burn_public(ownerAddress, withdrawAmount, nonce).request(), + ); + await user1Wallet.setPublicAuthWit(burnMessageHash, true).send().wait(); + + // 5. Withdraw owner's funds from L2 to L1 + logger.verbose('5. Withdraw owner funds from L2 to L1'); + const l2ToL1Message = crossChainTestHarness.getL2ToL1MessageLeaf(withdrawAmount); + const l2TxReceipt = await crossChainTestHarness.withdrawPublicFromAztecToL1(withdrawAmount, nonce); + await crossChainTestHarness.expectPublicBalanceOnL2(ownerAddress, afterBalance - withdrawAmount); + + // Check balance before and after exit. + expect(await crossChainTestHarness.getL1BalanceOf(ethAccount)).toBe(l1TokenBalance - bridgeAmount); + + const [l2ToL1MessageIndex, siblingPath] = await aztecNode.getL2ToL1MessageMembershipWitness( + l2TxReceipt.blockNumber!, + l2ToL1Message, + ); + + await crossChainTestHarness.withdrawFundsFromBridgeOnL1( + withdrawAmount, + l2TxReceipt.blockNumber!, + l2ToL1MessageIndex, + siblingPath, + ); + expect(await crossChainTestHarness.getL1BalanceOf(ethAccount)).toBe(l1TokenBalance - bridgeAmount + withdrawAmount); + }, 120_000); + // docs:end:e2e_public_cross_chain + + it('Someone else can mint funds to me on my behalf (publicly)', async () => { + // Generate a claim secret using pedersen + const l1TokenBalance = 1000000n; + const bridgeAmount = 100n; + + const [secret, secretHash] = crossChainTestHarness.generateClaimSecret(); + + await crossChainTestHarness.mintTokensOnL1(l1TokenBalance); + const msgHash = await crossChainTestHarness.sendTokensToPortalPublic(bridgeAmount, secretHash); + expect(await crossChainTestHarness.getL1BalanceOf(ethAccount)).toBe(l1TokenBalance - bridgeAmount); + + await crossChainTestHarness.makeMessageConsumable(msgHash); + + const content = sha256ToField([ + Buffer.from(toFunctionSelector('mint_public(bytes32,uint256)').substring(2), 'hex'), + user2Wallet.getAddress(), + new Fr(bridgeAmount), + ]); + const wrongMessage = new L1ToL2Message( + new L1Actor(crossChainTestHarness.tokenPortalAddress, crossChainTestHarness.publicClient.chain.id), + new L2Actor(l2Bridge.address, 1), + content, + secretHash, + ); + + // get message leaf index, needed for claiming in public + const maybeIndexAndPath = await aztecNode.getL1ToL2MessageMembershipWitness('latest', msgHash, 0n); + expect(maybeIndexAndPath).toBeDefined(); + const messageLeafIndex = maybeIndexAndPath![0]; + + // user2 tries to consume this message and minting to itself -> should fail since the message is intended to be consumed only by owner. + await expect( + l2Bridge + .withWallet(user2Wallet) + .methods.claim_public(user2Wallet.getAddress(), bridgeAmount, secret, messageLeafIndex) + .prove(), + ).rejects.toThrow(`No non-nullified L1 to L2 message found for message hash ${wrongMessage.hash().toString()}`); + + // user2 consumes owner's L1-> L2 message on bridge contract and mints public tokens on L2 + logger.info("user2 consumes owner's message on L2 Publicly"); + await l2Bridge + .withWallet(user2Wallet) + .methods.claim_public(ownerAddress, bridgeAmount, secret, messageLeafIndex) + .send() + .wait(); + // ensure funds are gone to owner and not user2. + await crossChainTestHarness.expectPublicBalanceOnL2(ownerAddress, bridgeAmount); + await crossChainTestHarness.expectPublicBalanceOnL2(user2Wallet.getAddress(), 0n); + }, 90_000); +}); diff --git a/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/failure_cases.test.ts b/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/failure_cases.test.ts new file mode 100644 index 000000000000..8e8bb4f1bb98 --- /dev/null +++ b/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/failure_cases.test.ts @@ -0,0 +1,68 @@ +import { EthAddress, Fr, L1Actor, L1ToL2Message, L2Actor } from '@aztec/aztec.js'; +import { sha256ToField } from '@aztec/foundation/crypto'; + +import { toFunctionSelector } from 'viem'; + +import { PublicCrossChainMessagingContractTest } from './public_cross_chain_messaging_contract_test.js'; + +describe('e2e_public_cross_chain_messaging failures', () => { + const t = new PublicCrossChainMessagingContractTest('failures'); + + let { crossChainTestHarness, ethAccount, l2Bridge, user1Wallet, user2Wallet } = t; + + beforeAll(async () => { + await t.applyBaseSnapshots(); + await t.setup(); + // Have to destructure again to ensure we have latest refs. + ({ crossChainTestHarness, user1Wallet, user2Wallet } = t); + ethAccount = crossChainTestHarness.ethAccount; + l2Bridge = crossChainTestHarness.l2Bridge; + }, 200_000); + + afterAll(async () => { + await t.teardown(); + }); + + it("Bridge can't withdraw my funds if I don't give approval", async () => { + const mintAmountToOwner = 100n; + await crossChainTestHarness.mintTokensPublicOnL2(mintAmountToOwner); + + const withdrawAmount = 9n; + const nonce = Fr.random(); + // Should fail as owner has not given approval to bridge burn their funds. + await expect( + l2Bridge + .withWallet(user1Wallet) + .methods.exit_to_l1_public(ethAccount, withdrawAmount, EthAddress.ZERO, nonce) + .prove(), + ).rejects.toThrow('Assertion failed: Message not authorized by account'); + }, 60_000); + + it("can't claim funds privately which were intended for public deposit from the token portal", async () => { + const bridgeAmount = 100n; + const [secret, secretHash] = crossChainTestHarness.generateClaimSecret(); + + await crossChainTestHarness.mintTokensOnL1(bridgeAmount); + const msgHash = await crossChainTestHarness.sendTokensToPortalPublic(bridgeAmount, secretHash); + expect(await crossChainTestHarness.getL1BalanceOf(ethAccount)).toBe(0n); + + await crossChainTestHarness.makeMessageConsumable(msgHash); + + // Wrong message hash + const content = sha256ToField([ + Buffer.from(toFunctionSelector('mint_private(bytes32,uint256)').substring(2), 'hex'), + secretHash, + new Fr(bridgeAmount), + ]); + const wrongMessage = new L1ToL2Message( + new L1Actor(crossChainTestHarness.tokenPortalAddress, crossChainTestHarness.publicClient.chain.id), + new L2Actor(l2Bridge.address, 1), + content, + secretHash, + ); + + await expect( + l2Bridge.withWallet(user2Wallet).methods.claim_private(secretHash, bridgeAmount, secret).prove(), + ).rejects.toThrow(`No non-nullified L1 to L2 message found for message hash ${wrongMessage.hash().toString()}`); + }, 60_000); +}); diff --git a/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/l1_to_l2.test.ts b/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/l1_to_l2.test.ts new file mode 100644 index 000000000000..9285a56a36d6 --- /dev/null +++ b/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/l1_to_l2.test.ts @@ -0,0 +1,122 @@ +import { + type EthAddressLike, + type FieldLike, + Fr, + L1Actor, + L1ToL2Message, + L2Actor, + computeSecretHash, +} from '@aztec/aztec.js'; +import { InboxAbi } from '@aztec/l1-artifacts'; +import { TestContract } from '@aztec/noir-contracts.js'; + +import { type Hex, decodeEventLog } from 'viem'; + +import { PublicCrossChainMessagingContractTest } from './public_cross_chain_messaging_contract_test.js'; + +describe('e2e_public_cross_chain_messaging l1_to_l2', () => { + const t = new PublicCrossChainMessagingContractTest('l1_to_l2'); + + let { crossChainTestHarness, aztecNode, user1Wallet, inbox } = t; + + beforeAll(async () => { + await t.applyBaseSnapshots(); + await t.setup(); + // Have to destructure again to ensure we have latest refs. + ({ crossChainTestHarness, user1Wallet } = t); + + aztecNode = crossChainTestHarness.aztecNode; + inbox = crossChainTestHarness.inbox; + }, 200_000); + + afterAll(async () => { + await t.teardown(); + }); + + // Note: We register one portal address when deploying contract but that address is no-longer the only address + // allowed to send messages to the given contract. In the following test we'll test that it's really the case. + it.each([true, false])( + 'can send an L1 -> L2 message from a non-registered portal address consumed from private or public and then sends and claims exactly the same message again', + async (isPrivate: boolean) => { + const testContract = await TestContract.deploy(user1Wallet).send().deployed(); + + const consumeMethod = isPrivate + ? (content: FieldLike, secret: FieldLike, sender: EthAddressLike, _leafIndex: FieldLike) => + testContract.methods.consume_message_from_arbitrary_sender_private(content, secret, sender) + : testContract.methods.consume_message_from_arbitrary_sender_public; + + const secret = Fr.random(); + + const message = new L1ToL2Message( + new L1Actor(crossChainTestHarness.ethAccount, crossChainTestHarness.publicClient.chain.id), + new L2Actor(testContract.address, 1), + Fr.random(), // content + computeSecretHash(secret), // secretHash + ); + + await sendL2Message(message); + + const [message1Index, _1] = (await aztecNode.getL1ToL2MessageMembershipWitness('latest', message.hash(), 0n))!; + + // Finally, we consume the L1 -> L2 message using the test contract either from private or public + await consumeMethod(message.content, secret, message.sender.sender, message1Index).send().wait(); + + // We send and consume the exact same message the second time to test that oracles correctly return the new + // non-nullified message + await sendL2Message(message); + + // We check that the duplicate message was correctly inserted by checking that its message index is defined and + // larger than the previous message index + const [message2Index, _2] = (await aztecNode.getL1ToL2MessageMembershipWitness( + 'latest', + message.hash(), + message1Index + 1n, + ))!; + + expect(message2Index).toBeDefined(); + expect(message2Index).toBeGreaterThan(message1Index); + + // Now we consume the message again. Everything should pass because oracle should return the duplicate message + // which is not nullified + await consumeMethod(message.content, secret, message.sender.sender, message2Index).send().wait(); + }, + 120_000, + ); + + const sendL2Message = async (message: L1ToL2Message) => { + // We inject the message to Inbox + const txHash = await inbox.write.sendL2Message( + [ + { actor: message.recipient.recipient.toString() as Hex, version: 1n }, + message.content.toString() as Hex, + message.secretHash.toString() as Hex, + ] as const, + {} as any, + ); + + // We check that the message was correctly injected by checking the emitted event + const msgHash = message.hash(); + { + const txReceipt = await crossChainTestHarness.publicClient.waitForTransactionReceipt({ + hash: txHash, + }); + + // Exactly 1 event should be emitted in the transaction + expect(txReceipt.logs.length).toBe(1); + + // We decode the event and get leaf out of it + const txLog = txReceipt.logs[0]; + const topics = decodeEventLog({ + abi: InboxAbi, + data: txLog.data, + topics: txLog.topics, + }); + const receivedMsgHash = topics.args.hash; + + // We check that the leaf inserted into the subtree matches the expected message hash + expect(receivedMsgHash).toBe(msgHash.toString()); + } + + await crossChainTestHarness.makeMessageConsumable(msgHash); + }; +}); diff --git a/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/l2_to_l1.test.ts b/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/l2_to_l1.test.ts new file mode 100644 index 000000000000..b43333c5edf7 --- /dev/null +++ b/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/l2_to_l1.test.ts @@ -0,0 +1,116 @@ +import { Fr } from '@aztec/aztec.js'; +import { sha256ToField } from '@aztec/foundation/crypto'; +import { OutboxAbi } from '@aztec/l1-artifacts'; +import { TestContract } from '@aztec/noir-contracts.js'; + +import { type Hex, decodeEventLog } from 'viem'; + +import { PublicCrossChainMessagingContractTest } from './public_cross_chain_messaging_contract_test.js'; + +describe('e2e_public_cross_chain_messaging l2_to_l1', () => { + const t = new PublicCrossChainMessagingContractTest('l2_to_l1'); + + let { crossChainTestHarness, aztecNode, user1Wallet, outbox } = t; + + beforeAll(async () => { + await t.applyBaseSnapshots(); + await t.setup(); + // Have to destructure again to ensure we have latest refs. + ({ crossChainTestHarness, user1Wallet } = t); + + aztecNode = crossChainTestHarness.aztecNode; + + outbox = crossChainTestHarness.outbox; + }, 200_000); + + afterAll(async () => { + await t.teardown(); + }); + + // Note: We register one portal address when deploying contract but that address is no-longer the only address + // allowed to receive messages from the given contract. In the following test we'll test that it's really the case. + it.each([[true], [false]])( + `can send an L2 -> L1 message to a non-registered portal address from public or private`, + async (isPrivate: boolean) => { + const testContract = await TestContract.deploy(user1Wallet).send().deployed(); + + const content = Fr.random(); + const recipient = crossChainTestHarness.ethAccount; + + let l2TxReceipt; + + // We create the L2 -> L1 message using the test contract + if (isPrivate) { + l2TxReceipt = await testContract.methods + .create_l2_to_l1_message_arbitrary_recipient_private(content, recipient) + .send() + .wait(); + } else { + l2TxReceipt = await testContract.methods + .create_l2_to_l1_message_arbitrary_recipient_public(content, recipient) + .send() + .wait(); + } + + const l2ToL1Message = { + sender: { actor: testContract.address.toString() as Hex, version: 1n }, + recipient: { + actor: recipient.toString() as Hex, + chainId: BigInt(crossChainTestHarness.publicClient.chain.id), + }, + content: content.toString() as Hex, + }; + + const leaf = sha256ToField([ + testContract.address, + new Fr(1), // aztec version + recipient.toBuffer32(), + new Fr(crossChainTestHarness.publicClient.chain.id), // chain id + content, + ]); + + const [l2MessageIndex, siblingPath] = await aztecNode.getL2ToL1MessageMembershipWitness( + l2TxReceipt.blockNumber!, + leaf, + ); + + const txHash = await outbox.write.consume( + [ + l2ToL1Message, + BigInt(l2TxReceipt.blockNumber!), + BigInt(l2MessageIndex), + siblingPath.toBufferArray().map((buf: Buffer) => `0x${buf.toString('hex')}`) as readonly `0x${string}`[], + ], + {} as any, + ); + + const txReceipt = await crossChainTestHarness.publicClient.waitForTransactionReceipt({ + hash: txHash, + }); + + // Exactly 1 event should be emitted in the transaction + expect(txReceipt.logs.length).toBe(1); + + // We decode the event log before checking it + const txLog = txReceipt.logs[0]; + const topics = decodeEventLog({ + abi: OutboxAbi, + data: txLog.data, + topics: txLog.topics, + }) as { + eventName: 'MessageConsumed'; + args: { + l2BlockNumber: bigint; + root: `0x${string}`; + messageHash: `0x${string}`; + leafIndex: bigint; + }; + }; + + // We check that MessageConsumed event was emitted with the expected message hash and leaf index + expect(topics.args.messageHash).toStrictEqual(leaf.toString()); + expect(topics.args.leafIndex).toStrictEqual(BigInt(0)); + }, + 60_000, + ); +}); diff --git a/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/public_cross_chain_messaging_contract_test.ts b/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/public_cross_chain_messaging_contract_test.ts new file mode 100644 index 000000000000..38add3ebc8a4 --- /dev/null +++ b/yarn-project/end-to-end/src/e2e_public_cross_chain_messaging/public_cross_chain_messaging_contract_test.ts @@ -0,0 +1,225 @@ +import { getSchnorrAccount } from '@aztec/accounts/schnorr'; +import { type AztecNodeConfig } from '@aztec/aztec-node'; +import { + type AccountWallet, + AztecAddress, + type AztecNode, + type CompleteAddress, + type DebugLogger, + EthAddress, + type PXE, + createDebugLogger, +} from '@aztec/aztec.js'; +import { InboxAbi, OutboxAbi, PortalERC20Abi, TokenPortalAbi } from '@aztec/l1-artifacts'; +import { TokenBridgeContract, TokenContract } from '@aztec/noir-contracts.js'; + +import { + type Chain, + type HttpTransport, + type PublicClient, + createPublicClient, + createWalletClient, + getContract, + http, +} from 'viem'; +import { mnemonicToAccount } from 'viem/accounts'; +import { foundry } from 'viem/chains'; + +import { MNEMONIC } from '../fixtures/fixtures.js'; +import { + SnapshotManager, + type SubsystemsContext, + addAccounts, + publicDeployAccounts, +} from '../fixtures/snapshot_manager.js'; +import { CrossChainTestHarness } from '../shared/cross_chain_test_harness.js'; + +const { E2E_DATA_PATH: dataPath } = process.env; + +export class PublicCrossChainMessagingContractTest { + private snapshotManager: SnapshotManager; + logger: DebugLogger; + wallets: AccountWallet[] = []; + accounts: CompleteAddress[] = []; + aztecNode!: AztecNode; + pxe!: PXE; + aztecNodeConfig!: AztecNodeConfig; + + publicClient!: PublicClient | undefined; + + user1Wallet!: AccountWallet; + user2Wallet!: AccountWallet; + crossChainTestHarness!: CrossChainTestHarness; + ethAccount!: EthAddress; + ownerAddress!: AztecAddress; + l2Token!: TokenContract; + l2Bridge!: TokenBridgeContract; + + inbox!: any; // GetContractReturnType | undefined; + outbox!: any; // GetContractReturnType | undefined; + + constructor(testName: string) { + this.logger = createDebugLogger(`aztec:e2e_public_cross_chain_messaging:${testName}`); + this.snapshotManager = new SnapshotManager(`e2e_public_cross_chain_messaging/${testName}`, dataPath); + } + + async setup() { + const { aztecNode, pxe, aztecNodeConfig } = await this.snapshotManager.setup(); + this.aztecNode = aztecNode; + this.pxe = pxe; + this.aztecNodeConfig = aztecNodeConfig; + } + + snapshot = ( + name: string, + apply: (context: SubsystemsContext) => Promise, + restore: (snapshotData: T, context: SubsystemsContext) => Promise = () => Promise.resolve(), + ): Promise => this.snapshotManager.snapshot(name, apply, restore); + + async teardown() { + await this.snapshotManager.teardown(); + } + + viemStuff(rpcUrl: string) { + const hdAccount = mnemonicToAccount(MNEMONIC); + + const walletClient = createWalletClient({ + account: hdAccount, + chain: foundry, + transport: http(rpcUrl), + }); + const publicClient = createPublicClient({ + chain: foundry, + transport: http(rpcUrl), + }); + + return { walletClient, publicClient }; + } + + async applyBaseSnapshots() { + // Note that we are using the same `pxe`, `aztecNodeConfig` and `aztecNode` across all snapshots. + // This is to not have issues with different networks. + + await this.snapshotManager.snapshot( + '3_accounts', + addAccounts(3, this.logger), + async ({ accountKeys }, { pxe, aztecNodeConfig, aztecNode }) => { + const accountManagers = accountKeys.map(ak => getSchnorrAccount(pxe, ak[0], ak[1], 1)); + this.wallets = await Promise.all(accountManagers.map(a => a.getWallet())); + this.wallets.forEach((w, i) => this.logger.verbose(`Wallet ${i} address: ${w.getAddress()}`)); + this.accounts = await pxe.getRegisteredAccounts(); + + this.user1Wallet = this.wallets[0]; + this.user2Wallet = this.wallets[1]; + + this.pxe = pxe; + this.aztecNode = aztecNode; + this.aztecNodeConfig = aztecNodeConfig; + }, + ); + + await this.snapshotManager.snapshot( + 'e2e_public_cross_chain_messaging', + async () => { + // Create the token contract state. + // Move this account thing to addAccounts above? + this.logger.verbose(`Public deploy accounts...`); + await publicDeployAccounts(this.wallets[0], this.accounts.slice(0, 3)); + + const { publicClient, walletClient } = this.viemStuff(this.aztecNodeConfig.rpcUrl); + + this.logger.verbose(`Setting up cross chain harness...`); + this.crossChainTestHarness = await CrossChainTestHarness.new( + this.aztecNode, + this.pxe, + publicClient, + walletClient, + this.wallets[0], + this.logger, + ); + + this.logger.verbose(`L2 token deployed to: ${this.crossChainTestHarness.l2Token.address}`); + + return this.toCrossChainContext(); + }, + async crossChainContext => { + this.l2Token = await TokenContract.at(crossChainContext.l2Token, this.user1Wallet); + this.l2Bridge = await TokenBridgeContract.at(crossChainContext.l2Bridge, this.user1Wallet); + + // There is an issue with the reviver so we are getting strings sometimes. Working around it here. + this.ethAccount = EthAddress.fromString(crossChainContext.ethAccount.toString()); + this.ownerAddress = AztecAddress.fromString(crossChainContext.ownerAddress.toString()); + const tokenPortalAddress = EthAddress.fromString(crossChainContext.tokenPortal.toString()); + + const { publicClient, walletClient } = this.viemStuff(this.aztecNodeConfig.rpcUrl); + + const inbox = getContract({ + address: this.aztecNodeConfig.l1Contracts.inboxAddress.toString(), + abi: InboxAbi, + client: walletClient, + }); + const outbox = getContract({ + address: this.aztecNodeConfig.l1Contracts.outboxAddress.toString(), + abi: OutboxAbi, + client: walletClient, + }); + + const tokenPortal = getContract({ + address: tokenPortalAddress.toString(), + abi: TokenPortalAbi, + client: walletClient, + }); + const underlyingERC20 = getContract({ + address: crossChainContext.underlying.toString(), + abi: PortalERC20Abi, + client: walletClient, + }); + + this.crossChainTestHarness = new CrossChainTestHarness( + this.aztecNode, + this.pxe, + this.logger, + this.l2Token, + this.l2Bridge, + this.ethAccount, + tokenPortalAddress, + tokenPortal, + underlyingERC20, + inbox, + outbox, + publicClient, + walletClient, + this.ownerAddress, + ); + + this.publicClient = publicClient; + this.inbox = inbox; + this.outbox = outbox; + }, + ); + } + + toCrossChainContext(): CrossChainContext { + return { + l2Token: this.crossChainTestHarness.l2Token.address, + l2Bridge: this.crossChainTestHarness.l2Bridge.address, + tokenPortal: this.crossChainTestHarness.tokenPortal.address, + underlying: EthAddress.fromString(this.crossChainTestHarness.underlyingERC20.address), + ethAccount: this.crossChainTestHarness.ethAccount, + ownerAddress: this.crossChainTestHarness.ownerAddress, + inbox: EthAddress.fromString(this.crossChainTestHarness.inbox.address), + outbox: EthAddress.fromString(this.crossChainTestHarness.outbox.address), + }; + } +} + +type CrossChainContext = { + l2Token: AztecAddress; + l2Bridge: AztecAddress; + tokenPortal: EthAddress; + underlying: EthAddress; + ethAccount: EthAddress; + ownerAddress: AztecAddress; + inbox: EthAddress; + outbox: EthAddress; +}; From 1449c338ca79f8d72b71484546aa46ddebb21779 Mon Sep 17 00:00:00 2001 From: ledwards2225 <98505400+ledwards2225@users.noreply.github.com> Date: Mon, 29 Apr 2024 16:28:03 -0700 Subject: [PATCH 03/42] fix: fix relation skipping for sumcheck (#6092) Due to a bad merge (or something?), the relation skipping for Sumcheck was not activated in the original PR. This fix corrects the syntax of the skipping check which turns on the skipping for Sumcheck only (full proof construction / decider). Branch (ultra honk bench) ``` --------------------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------------------- construct_proof_ultrahonk/sha256 324 ms 270 ms 3 construct_proof_ultrahonk/keccak 1235 ms 1023 ms 1 construct_proof_ultrahonk/ecdsa_verification 2402 ms 1996 ms 1 construct_proof_ultrahonk/merkle_membership 188 ms 156 ms 4 construct_proof_ultrahonk_power_of_2/15 200 ms 185 ms 4 construct_proof_ultrahonk_power_of_2/16 362 ms 335 ms 2 construct_proof_ultrahonk_power_of_2/17 684 ms 632 ms 1 construct_proof_ultrahonk_power_of_2/18 1350 ms 1248 ms 1 construct_proof_ultrahonk_power_of_2/19 2698 ms 2481 ms 1 construct_proof_ultrahonk_power_of_2/20 5176 ms 4784 ms 1 ``` Master (ultra honk bench) ``` --------------------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------------------- construct_proof_ultrahonk/sha256 423 ms 395 ms 2 construct_proof_ultrahonk/keccak 1628 ms 1515 ms 1 construct_proof_ultrahonk/ecdsa_verification 2971 ms 2764 ms 1 construct_proof_ultrahonk/merkle_membership 236 ms 218 ms 3 construct_proof_ultrahonk_power_of_2/15 248 ms 231 ms 3 construct_proof_ultrahonk_power_of_2/16 459 ms 430 ms 2 construct_proof_ultrahonk_power_of_2/17 872 ms 820 ms 1 construct_proof_ultrahonk_power_of_2/18 1773 ms 1656 ms 1 construct_proof_ultrahonk_power_of_2/19 3342 ms 3133 ms 1 construct_proof_ultrahonk_power_of_2/20 6694 ms 6289 ms 1 ``` Branch (client IVC) ``` -------------------------------------------------------------------------------- Benchmark Time CPU Iterations UserCounters... -------------------------------------------------------------------------------- ClientIVCBench/Full/6 21234 ms 16030 ms 1 function ms % sum construct_circuits(t) 4537 21.52% ProverInstance(Circuit&)(t) 1912 9.07% ProtogalaxyProver::fold_instances(t) 11102 52.64% Decider::construct_proof(t) 577 2.73% ECCVMProver(CircuitBuilder&)(t) 129 0.61% ECCVMProver::construct_proof(t) 1765 8.37% GoblinTranslatorProver::construct_proof(t) 930 4.41% Goblin::merge(t) 137 0.65% ``` Master (client IVC) ``` -------------------------------------------------------------------------------- Benchmark Time CPU Iterations UserCounters... -------------------------------------------------------------------------------- ClientIVCBench/Full/6 21385 ms 16321 ms 1 function ms % sum construct_circuits(t) 4544 21.40% ProverInstance(Circuit&)(t) 1919 9.03% ProtogalaxyProver::fold_instances(t) 11055 52.05% Decider::construct_proof(t) 728 3.43% ECCVMProver(CircuitBuilder&)(t) 132 0.62% ECCVMProver::construct_proof(t) 1757 8.27% GoblinTranslatorProver::construct_proof(t) 965 4.55% Goblin::merge(t) 138 0.65% ``` --- barretenberg/cpp/src/barretenberg/relations/relation_types.hpp | 2 +- barretenberg/cpp/src/barretenberg/sumcheck/sumcheck_round.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/relations/relation_types.hpp b/barretenberg/cpp/src/barretenberg/relations/relation_types.hpp index aa4ba8820b7c..e134c4e03413 100644 --- a/barretenberg/cpp/src/barretenberg/relations/relation_types.hpp +++ b/barretenberg/cpp/src/barretenberg/relations/relation_types.hpp @@ -122,7 +122,7 @@ consteval std::array compute_composed_subrelation_part template concept isSkippable = requires(const AllEntities& input) { { - Relation::is_active(input) + Relation::skip(input) } -> std::same_as; }; diff --git a/barretenberg/cpp/src/barretenberg/sumcheck/sumcheck_round.hpp b/barretenberg/cpp/src/barretenberg/sumcheck/sumcheck_round.hpp index de237e915de3..c1ac763379ab 100644 --- a/barretenberg/cpp/src/barretenberg/sumcheck/sumcheck_round.hpp +++ b/barretenberg/cpp/src/barretenberg/sumcheck/sumcheck_round.hpp @@ -245,7 +245,7 @@ template class SumcheckProverRound { std::get(univariate_accumulators), extended_edges, relation_parameters, scaling_factor); } else { // If so, only compute the contribution if the relation is active - if (!Relation::skip(extended_edges, relation_parameters)) { + if (!Relation::skip(extended_edges)) { Relation::accumulate(std::get(univariate_accumulators), extended_edges, relation_parameters, From 99fbb74f48f7ceb2cc88d1e5bf9b91b8d61d0493 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 30 Apr 2024 02:09:40 +0000 Subject: [PATCH 04/42] git subrepo push --branch=master barretenberg subrepo: subdir: "barretenberg" merged: "31ec60891" upstream: origin: "https://github.com/AztecProtocol/barretenberg" branch: "master" commit: "31ec60891" git-subrepo: version: "0.4.6" origin: "???" commit: "???" [skip ci] --- barretenberg/.gitrepo | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/barretenberg/.gitrepo b/barretenberg/.gitrepo index 3232d03252c3..ee434d12668a 100644 --- a/barretenberg/.gitrepo +++ b/barretenberg/.gitrepo @@ -6,7 +6,7 @@ [subrepo] remote = https://github.com/AztecProtocol/barretenberg branch = master - commit = 795c999a3b7fe8d85af05ffc09ddfd349c00e5a4 - parent = 0a64279ba1b2b3bb6627c675b8a0b116be17f579 + commit = 31ec6089135e12bf85240d50bc8bac066918dfa0 + parent = 1449c338ca79f8d72b71484546aa46ddebb21779 method = merge cmdver = 0.4.6 From b723534db2fcbd3399aca722354df7c45ee8a84f Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 30 Apr 2024 02:10:07 +0000 Subject: [PATCH 05/42] chore: replace relative paths to noir-protocol-circuits --- noir-projects/aztec-nr/aztec/Nargo.toml | 2 +- noir-projects/aztec-nr/tests/Nargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/noir-projects/aztec-nr/aztec/Nargo.toml b/noir-projects/aztec-nr/aztec/Nargo.toml index 7a1f1af58631..2cbb43ab2787 100644 --- a/noir-projects/aztec-nr/aztec/Nargo.toml +++ b/noir-projects/aztec-nr/aztec/Nargo.toml @@ -5,4 +5,4 @@ compiler_version = ">=0.18.0" type = "lib" [dependencies] -protocol_types = { path = "../../noir-protocol-circuits/crates/types" } +protocol_types = { git="https://github.com/AztecProtocol/aztec-packages", tag="aztec-packages-v0.35.1", directory="noir-projects/noir-protocol-circuits/crates/types" } diff --git a/noir-projects/aztec-nr/tests/Nargo.toml b/noir-projects/aztec-nr/tests/Nargo.toml index 13404b373243..dfed895aad0c 100644 --- a/noir-projects/aztec-nr/tests/Nargo.toml +++ b/noir-projects/aztec-nr/tests/Nargo.toml @@ -6,4 +6,4 @@ type = "lib" [dependencies] aztec = { path = "../aztec" } -protocol_types = { path = "../../noir-protocol-circuits/crates/types" } +protocol_types = { git="https://github.com/AztecProtocol/aztec-packages", tag="aztec-packages-v0.35.1", directory="noir-projects/noir-protocol-circuits/crates/types" } From 655bb8f87242b0cab45961b86825ab1db202d705 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 30 Apr 2024 02:10:07 +0000 Subject: [PATCH 06/42] git_subrepo.sh: Fix parent in .gitrepo file. [skip ci] --- noir-projects/aztec-nr/.gitrepo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noir-projects/aztec-nr/.gitrepo b/noir-projects/aztec-nr/.gitrepo index 6d9c9a29f805..425f564cc489 100644 --- a/noir-projects/aztec-nr/.gitrepo +++ b/noir-projects/aztec-nr/.gitrepo @@ -9,4 +9,4 @@ commit = 071b146a0fa3951fdd05b2a2732bac331bc79f73 method = merge cmdver = 0.4.6 - parent = 2a14f3b48f79177094fc77fd3cc22bf779515ad0 + parent = 8e7d592b42004e060186f87e40504804532c9b22 From 4a2a3c8462ae42bd4972e87151d22a360c141fb2 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 30 Apr 2024 02:10:10 +0000 Subject: [PATCH 07/42] git subrepo push --branch=master noir-projects/aztec-nr subrepo: subdir: "noir-projects/aztec-nr" merged: "a27226a76" upstream: origin: "https://github.com/AztecProtocol/aztec-nr" branch: "master" commit: "a27226a76" git-subrepo: version: "0.4.6" origin: "???" commit: "???" [skip ci] --- noir-projects/aztec-nr/.gitrepo | 4 ++-- noir-projects/aztec-nr/aztec/Nargo.toml | 2 +- noir-projects/aztec-nr/tests/Nargo.toml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/noir-projects/aztec-nr/.gitrepo b/noir-projects/aztec-nr/.gitrepo index 425f564cc489..08320c53a994 100644 --- a/noir-projects/aztec-nr/.gitrepo +++ b/noir-projects/aztec-nr/.gitrepo @@ -6,7 +6,7 @@ [subrepo] remote = https://github.com/AztecProtocol/aztec-nr branch = master - commit = 071b146a0fa3951fdd05b2a2732bac331bc79f73 + commit = a27226a7655c81f7b9f3bc2457eb860100eed9cc method = merge cmdver = 0.4.6 - parent = 8e7d592b42004e060186f87e40504804532c9b22 + parent = bc2e3073176f2fe7f254e83ac082c2fad5e491f0 diff --git a/noir-projects/aztec-nr/aztec/Nargo.toml b/noir-projects/aztec-nr/aztec/Nargo.toml index 2cbb43ab2787..7a1f1af58631 100644 --- a/noir-projects/aztec-nr/aztec/Nargo.toml +++ b/noir-projects/aztec-nr/aztec/Nargo.toml @@ -5,4 +5,4 @@ compiler_version = ">=0.18.0" type = "lib" [dependencies] -protocol_types = { git="https://github.com/AztecProtocol/aztec-packages", tag="aztec-packages-v0.35.1", directory="noir-projects/noir-protocol-circuits/crates/types" } +protocol_types = { path = "../../noir-protocol-circuits/crates/types" } diff --git a/noir-projects/aztec-nr/tests/Nargo.toml b/noir-projects/aztec-nr/tests/Nargo.toml index dfed895aad0c..13404b373243 100644 --- a/noir-projects/aztec-nr/tests/Nargo.toml +++ b/noir-projects/aztec-nr/tests/Nargo.toml @@ -6,4 +6,4 @@ type = "lib" [dependencies] aztec = { path = "../aztec" } -protocol_types = { git="https://github.com/AztecProtocol/aztec-packages", tag="aztec-packages-v0.35.1", directory="noir-projects/noir-protocol-circuits/crates/types" } +protocol_types = { path = "../../noir-protocol-circuits/crates/types" } From 81142fe799338e6ed73b30eeac4468c1345f6fab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Bene=C5=A1?= Date: Tue, 30 Apr 2024 13:52:23 +0200 Subject: [PATCH 08/42] feat: `variable_base_scalar_mul` blackbox func (#6039) --- .../dsl/acir_format/acir_format.cpp | 5 + .../dsl/acir_format/acir_format.hpp | 3 + .../dsl/acir_format/acir_format.test.cpp | 6 + .../acir_format/acir_to_constraint_buf.hpp | 9 + .../acir_format/bigint_constraint.test.cpp | 5 + .../dsl/acir_format/block_constraint.test.cpp | 1 + .../dsl/acir_format/ec_operations.test.cpp | 1 + .../dsl/acir_format/ecdsa_secp256k1.test.cpp | 3 + .../dsl/acir_format/ecdsa_secp256r1.test.cpp | 4 + .../acir_format/poseidon2_constraint.test.cpp | 1 + .../acir_format/recursion_constraint.test.cpp | 2 + .../dsl/acir_format/serde/acir.hpp | 163 ++++++++++++++++++ .../acir_format/sha256_constraint.test.cpp | 1 + .../acir_format/variable_base_scalar_mul.cpp | 38 ++++ .../acir_format/variable_base_scalar_mul.hpp | 23 +++ .../crates/types/src/grumpkin_point.nr | 1 + noir/noir-repo/acvm-repo/acir/README.md | 9 + .../noir-repo/acvm-repo/acir/codegen/acir.cpp | 128 +++++++++++++- .../acir/src/circuit/black_box_functions.rs | 6 +- .../opcodes/black_box_function_call.rs | 18 ++ .../acir/tests/test_program_serialization.rs | 32 ++++ .../src/pwg/blackbox/fixed_base_scalar_mul.rs | 24 +++ .../acvm-repo/acvm/src/pwg/blackbox/mod.rs | 17 +- .../test/browser/execute_circuit.test.ts | 10 ++ .../acvm_js/test/node/execute_circuit.test.ts | 10 ++ .../test/shared/variable_base_scalar_mul.ts | 21 +++ .../src/curve_specific_solver.rs | 16 ++ .../src/fixed_base_scalar_mul.rs | 106 +++++++++--- .../bn254_blackbox_solver/src/lib.rs | 14 +- .../acvm-repo/brillig/src/black_box.rs | 8 + .../acvm-repo/brillig_vm/src/black_box.rs | 14 ++ .../brillig/brillig_gen/brillig_black_box.rs | 21 ++- .../noirc_evaluator/src/brillig/brillig_ir.rs | 10 ++ .../src/brillig/brillig_ir/debug_show.rs | 17 ++ .../ssa/acir_gen/acir_ir/generated_acir.rs | 15 +- .../src/ssa/ir/instruction/call.rs | 1 + .../cryptographic_primitives/scalar.mdx | 30 +++- .../noir_stdlib/src/grumpkin_scalar_mul.nr | 5 +- noir/noir-repo/noir_stdlib/src/scalar_mul.nr | 23 ++- .../Nargo.toml | 2 +- .../Prover.toml | 0 .../src/main.nr | 0 .../variable_base_scalar_mul/Nargo.toml | 6 + .../variable_base_scalar_mul/Prover.toml | 4 + .../variable_base_scalar_mul/src/main.nr | 33 ++++ noir/noir-repo/tooling/lsp/src/solver.rs | 10 ++ .../end-to-end/src/e2e_state_vars.test.ts | 2 +- 47 files changed, 835 insertions(+), 43 deletions(-) create mode 100644 barretenberg/cpp/src/barretenberg/dsl/acir_format/variable_base_scalar_mul.cpp create mode 100644 barretenberg/cpp/src/barretenberg/dsl/acir_format/variable_base_scalar_mul.hpp create mode 100644 noir/noir-repo/acvm-repo/acvm_js/test/shared/variable_base_scalar_mul.ts rename noir/noir-repo/test_programs/execution_success/{scalar_mul => fixed_base_scalar_mul}/Nargo.toml (63%) rename noir/noir-repo/test_programs/execution_success/{scalar_mul => fixed_base_scalar_mul}/Prover.toml (100%) rename noir/noir-repo/test_programs/execution_success/{scalar_mul => fixed_base_scalar_mul}/src/main.nr (100%) create mode 100644 noir/noir-repo/test_programs/execution_success/variable_base_scalar_mul/Nargo.toml create mode 100644 noir/noir-repo/test_programs/execution_success/variable_base_scalar_mul/Prover.toml create mode 100644 noir/noir-repo/test_programs/execution_success/variable_base_scalar_mul/src/main.nr diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp index 5ece3c031d48..736990478529 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp @@ -89,6 +89,11 @@ void build_constraints(Builder& builder, AcirFormat const& constraint_system, bo create_fixed_base_constraint(builder, constraint); } + // Add variable base scalar mul constraints + for (const auto& constraint : constraint_system.variable_base_scalar_mul_constraints) { + create_variable_base_constraint(builder, constraint); + } + // Add ec add constraints for (const auto& constraint : constraint_system.ec_add_constraints) { create_ec_add_constraint(builder, constraint, has_valid_witness_assignments); diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.hpp index ec82362e2d8b..a7f5b4757375 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.hpp @@ -17,6 +17,7 @@ #include "recursion_constraint.hpp" #include "schnorr_verify.hpp" #include "sha256_constraint.hpp" +#include "variable_base_scalar_mul.hpp" #include namespace acir_format { @@ -48,6 +49,7 @@ struct AcirFormat { std::vector pedersen_hash_constraints; std::vector poseidon2_constraints; std::vector fixed_base_scalar_mul_constraints; + std::vector variable_base_scalar_mul_constraints; std::vector ec_add_constraints; std::vector recursion_constraints; std::vector bigint_from_le_bytes_constraints; @@ -82,6 +84,7 @@ struct AcirFormat { pedersen_hash_constraints, poseidon2_constraints, fixed_base_scalar_mul_constraints, + variable_base_scalar_mul_constraints, ec_add_constraints, recursion_constraints, poly_triple_constraints, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.test.cpp index c7d44e319413..7de0f847b1f7 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.test.cpp @@ -48,6 +48,7 @@ TEST_F(AcirFormatTests, TestASingleConstraintNoPubInputs) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, @@ -164,6 +165,7 @@ TEST_F(AcirFormatTests, TestLogicGateFromNoirCircuit) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, @@ -232,6 +234,7 @@ TEST_F(AcirFormatTests, TestSchnorrVerifyPass) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, @@ -327,6 +330,7 @@ TEST_F(AcirFormatTests, TestSchnorrVerifySmallRange) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, @@ -441,6 +445,7 @@ TEST_F(AcirFormatTests, TestVarKeccak) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, @@ -488,6 +493,7 @@ TEST_F(AcirFormatTests, TestKeccakPermutation) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.hpp index 778363f3258c..f31fc73c806a 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.hpp @@ -317,6 +317,15 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, Aci .pub_key_x = arg.outputs[0].value, .pub_key_y = arg.outputs[1].value, }); + } else if constexpr (std::is_same_v) { + af.variable_base_scalar_mul_constraints.push_back(VariableBaseScalarMul{ + .point_x = arg.point_x.witness.value, + .point_y = arg.point_y.witness.value, + .scalar_low = arg.scalar_low.witness.value, + .scalar_high = arg.scalar_high.witness.value, + .out_point_x = arg.outputs[0].value, + .out_point_y = arg.outputs[1].value, + }); } else if constexpr (std::is_same_v) { af.ec_add_constraints.push_back(EcAdd{ .input1_x = arg.input1_x.witness.value, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.test.cpp index 584f7ef62a56..3b32c7f26950 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.test.cpp @@ -185,6 +185,7 @@ TEST_F(BigIntTests, TestBigIntConstraintMultiple) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, @@ -253,6 +254,7 @@ TEST_F(BigIntTests, TestBigIntConstraintSimple) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = { from_le_bytes_constraint_bigint1 }, @@ -306,6 +308,7 @@ TEST_F(BigIntTests, TestBigIntConstraintReuse) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, @@ -363,6 +366,7 @@ TEST_F(BigIntTests, TestBigIntConstraintReuse2) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, @@ -441,6 +445,7 @@ TEST_F(BigIntTests, TestBigIntDIV) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = { from_le_bytes_constraint_bigint1, from_le_bytes_constraint_bigint2 }, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/block_constraint.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/block_constraint.test.cpp index 20f9e8072bba..75b9150d335c 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/block_constraint.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/block_constraint.test.cpp @@ -127,6 +127,7 @@ TEST_F(UltraPlonkRAM, TestBlockConstraint) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.test.cpp index 92c76e3d7a37..bdda21409a17 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.test.cpp @@ -67,6 +67,7 @@ TEST_F(EcOperations, TestECOperations) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = { ec_add_constraint }, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/ecdsa_secp256k1.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/ecdsa_secp256k1.test.cpp index 0a11adb97be2..c494cc13e798 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/ecdsa_secp256k1.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/ecdsa_secp256k1.test.cpp @@ -107,6 +107,7 @@ TEST_F(ECDSASecp256k1, TestECDSAConstraintSucceed) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, @@ -156,6 +157,7 @@ TEST_F(ECDSASecp256k1, TestECDSACompilesForVerifier) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, @@ -200,6 +202,7 @@ TEST_F(ECDSASecp256k1, TestECDSAConstraintFail) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/ecdsa_secp256r1.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/ecdsa_secp256r1.test.cpp index 6cf542bc2d6e..6728445d2371 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/ecdsa_secp256r1.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/ecdsa_secp256r1.test.cpp @@ -141,6 +141,7 @@ TEST(ECDSASecp256r1, test_hardcoded) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, @@ -192,6 +193,7 @@ TEST(ECDSASecp256r1, TestECDSAConstraintSucceed) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, @@ -241,6 +243,7 @@ TEST(ECDSASecp256r1, TestECDSACompilesForVerifier) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, @@ -285,6 +288,7 @@ TEST(ECDSASecp256r1, TestECDSAConstraintFail) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.test.cpp index f509c262782a..f672505b4d70 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.test.cpp @@ -47,6 +47,7 @@ TEST_F(Poseidon2Tests, TestPoseidon2Permutation) .pedersen_hash_constraints = {}, .poseidon2_constraints = { poseidon2_constraint }, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/recursion_constraint.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/recursion_constraint.test.cpp index bbf7768abc91..97e53d30c627 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/recursion_constraint.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/recursion_constraint.test.cpp @@ -99,6 +99,7 @@ Builder create_inner_circuit() .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, @@ -256,6 +257,7 @@ Builder create_outer_circuit(std::vector& inner_circuits) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = recursion_constraints, .bigint_from_le_bytes_constraints = {}, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp index e2c9d020d6bc..bdfb6605ad24 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp @@ -145,6 +145,18 @@ struct BlackBoxFuncCall { static FixedBaseScalarMul bincodeDeserialize(std::vector); }; + struct VariableBaseScalarMul { + Program::FunctionInput point_x; + Program::FunctionInput point_y; + Program::FunctionInput scalar_low; + Program::FunctionInput scalar_high; + std::array outputs; + + friend bool operator==(const VariableBaseScalarMul&, const VariableBaseScalarMul&); + std::vector bincodeSerialize() const; + static VariableBaseScalarMul bincodeDeserialize(std::vector); + }; + struct EmbeddedCurveAdd { Program::FunctionInput input1_x; Program::FunctionInput input1_y; @@ -278,6 +290,7 @@ struct BlackBoxFuncCall { EcdsaSecp256k1, EcdsaSecp256r1, FixedBaseScalarMul, + VariableBaseScalarMul, EmbeddedCurveAdd, Keccak256, Keccakf1600, @@ -753,6 +766,18 @@ struct BlackBoxOp { static FixedBaseScalarMul bincodeDeserialize(std::vector); }; + struct VariableBaseScalarMul { + Program::MemoryAddress point_x; + Program::MemoryAddress point_y; + Program::MemoryAddress scalar_low; + Program::MemoryAddress scalar_high; + Program::HeapArray result; + + friend bool operator==(const VariableBaseScalarMul&, const VariableBaseScalarMul&); + std::vector bincodeSerialize() const; + static VariableBaseScalarMul bincodeDeserialize(std::vector); + }; + struct EmbeddedCurveAdd { Program::MemoryAddress input1_x; Program::MemoryAddress input1_y; @@ -855,6 +880,7 @@ struct BlackBoxOp { PedersenCommitment, PedersenHash, FixedBaseScalarMul, + VariableBaseScalarMul, EmbeddedCurveAdd, BigIntAdd, BigIntSub, @@ -3068,6 +3094,75 @@ Program::BlackBoxFuncCall::FixedBaseScalarMul serde::Deserializable< namespace Program { +inline bool operator==(const BlackBoxFuncCall::VariableBaseScalarMul& lhs, + const BlackBoxFuncCall::VariableBaseScalarMul& rhs) +{ + if (!(lhs.point_x == rhs.point_x)) { + return false; + } + if (!(lhs.point_y == rhs.point_y)) { + return false; + } + if (!(lhs.scalar_low == rhs.scalar_low)) { + return false; + } + if (!(lhs.scalar_high == rhs.scalar_high)) { + return false; + } + if (!(lhs.outputs == rhs.outputs)) { + return false; + } + return true; +} + +inline std::vector BlackBoxFuncCall::VariableBaseScalarMul::bincodeSerialize() const +{ + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); +} + +inline BlackBoxFuncCall::VariableBaseScalarMul BlackBoxFuncCall::VariableBaseScalarMul::bincodeDeserialize( + std::vector input) +{ + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw_or_abort("Some input bytes were not read"); + } + return value; +} + +} // end of namespace Program + +template <> +template +void serde::Serializable::serialize( + const Program::BlackBoxFuncCall::VariableBaseScalarMul& obj, Serializer& serializer) +{ + serde::Serializable::serialize(obj.point_x, serializer); + serde::Serializable::serialize(obj.point_y, serializer); + serde::Serializable::serialize(obj.scalar_low, serializer); + serde::Serializable::serialize(obj.scalar_high, serializer); + serde::Serializable::serialize(obj.outputs, serializer); +} + +template <> +template +Program::BlackBoxFuncCall::VariableBaseScalarMul serde::Deserializable< + Program::BlackBoxFuncCall::VariableBaseScalarMul>::deserialize(Deserializer& deserializer) +{ + Program::BlackBoxFuncCall::VariableBaseScalarMul obj; + obj.point_x = serde::Deserializable::deserialize(deserializer); + obj.point_y = serde::Deserializable::deserialize(deserializer); + obj.scalar_low = serde::Deserializable::deserialize(deserializer); + obj.scalar_high = serde::Deserializable::deserialize(deserializer); + obj.outputs = serde::Deserializable::deserialize(deserializer); + return obj; +} + +namespace Program { + inline bool operator==(const BlackBoxFuncCall::EmbeddedCurveAdd& lhs, const BlackBoxFuncCall::EmbeddedCurveAdd& rhs) { if (!(lhs.input1_x == rhs.input1_x)) { @@ -4444,6 +4539,74 @@ Program::BlackBoxOp::FixedBaseScalarMul serde::Deserializable BlackBoxOp::VariableBaseScalarMul::bincodeSerialize() const +{ + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); +} + +inline BlackBoxOp::VariableBaseScalarMul BlackBoxOp::VariableBaseScalarMul::bincodeDeserialize( + std::vector input) +{ + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw_or_abort("Some input bytes were not read"); + } + return value; +} + +} // end of namespace Program + +template <> +template +void serde::Serializable::serialize( + const Program::BlackBoxOp::VariableBaseScalarMul& obj, Serializer& serializer) +{ + serde::Serializable::serialize(obj.point_x, serializer); + serde::Serializable::serialize(obj.point_y, serializer); + serde::Serializable::serialize(obj.scalar_low, serializer); + serde::Serializable::serialize(obj.scalar_high, serializer); + serde::Serializable::serialize(obj.result, serializer); +} + +template <> +template +Program::BlackBoxOp::VariableBaseScalarMul serde::Deserializable< + Program::BlackBoxOp::VariableBaseScalarMul>::deserialize(Deserializer& deserializer) +{ + Program::BlackBoxOp::VariableBaseScalarMul obj; + obj.point_x = serde::Deserializable::deserialize(deserializer); + obj.point_y = serde::Deserializable::deserialize(deserializer); + obj.scalar_low = serde::Deserializable::deserialize(deserializer); + obj.scalar_high = serde::Deserializable::deserialize(deserializer); + obj.result = serde::Deserializable::deserialize(deserializer); + return obj; +} + +namespace Program { + inline bool operator==(const BlackBoxOp::EmbeddedCurveAdd& lhs, const BlackBoxOp::EmbeddedCurveAdd& rhs) { if (!(lhs.input1_x == rhs.input1_x)) { diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.test.cpp index 6266253ee552..aa520806b2b7 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.test.cpp @@ -49,6 +49,7 @@ TEST_F(Sha256Tests, TestSha256Compression) .pedersen_hash_constraints = {}, .poseidon2_constraints = {}, .fixed_base_scalar_mul_constraints = {}, + .variable_base_scalar_mul_constraints = {}, .ec_add_constraints = {}, .recursion_constraints = {}, .bigint_from_le_bytes_constraints = {}, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/variable_base_scalar_mul.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/variable_base_scalar_mul.cpp new file mode 100644 index 000000000000..6446b68158c4 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/variable_base_scalar_mul.cpp @@ -0,0 +1,38 @@ +#include "variable_base_scalar_mul.hpp" +#include "barretenberg/dsl/types.hpp" +#include "barretenberg/ecc/curves/bn254/fr.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" +#include "barretenberg/plonk_honk_shared/arithmetization/gate_data.hpp" + +namespace acir_format { + +template void create_variable_base_constraint(Builder& builder, const VariableBaseScalarMul& input) +{ + using cycle_group_ct = bb::stdlib::cycle_group; + using cycle_scalar_ct = typename bb::stdlib::cycle_group::cycle_scalar; + using field_ct = bb::stdlib::field_t; + + // We instantiate the input point/variable base as `cycle_group_ct` + auto point_x = field_ct::from_witness_index(&builder, input.point_x); + auto point_y = field_ct::from_witness_index(&builder, input.point_y); + cycle_group_ct input_point(point_x, point_y, false); + + // We reconstruct the scalar from the low and high limbs + field_ct scalar_low_as_field = field_ct::from_witness_index(&builder, input.scalar_low); + field_ct scalar_high_as_field = field_ct::from_witness_index(&builder, input.scalar_high); + cycle_scalar_ct scalar(scalar_low_as_field, scalar_high_as_field); + + // We multiply the scalar with input point/variable base to get the result + auto result = input_point * scalar; + + // Finally we add the constraints + builder.assert_equal(result.x.get_witness_index(), input.out_point_x); + builder.assert_equal(result.y.get_witness_index(), input.out_point_y); +} + +template void create_variable_base_constraint(UltraCircuitBuilder& builder, + const VariableBaseScalarMul& input); +template void create_variable_base_constraint(GoblinUltraCircuitBuilder& builder, + const VariableBaseScalarMul& input); + +} // namespace acir_format diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/variable_base_scalar_mul.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/variable_base_scalar_mul.hpp new file mode 100644 index 000000000000..d903df2cb322 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/variable_base_scalar_mul.hpp @@ -0,0 +1,23 @@ +#pragma once +#include "barretenberg/dsl/types.hpp" +#include "barretenberg/serialize/msgpack.hpp" +#include + +namespace acir_format { + +struct VariableBaseScalarMul { + uint32_t point_x; + uint32_t point_y; + uint32_t scalar_low; + uint32_t scalar_high; + uint32_t out_point_x; + uint32_t out_point_y; + + // for serialization, update with any new fields + MSGPACK_FIELDS(point_x, point_y, scalar_low, scalar_high, out_point_x, out_point_y); + friend bool operator==(VariableBaseScalarMul const& lhs, VariableBaseScalarMul const& rhs) = default; +}; + +template void create_variable_base_constraint(Builder& builder, const VariableBaseScalarMul& input); + +} // namespace acir_format diff --git a/noir-projects/noir-protocol-circuits/crates/types/src/grumpkin_point.nr b/noir-projects/noir-protocol-circuits/crates/types/src/grumpkin_point.nr index 965654ac144e..ec334d78b50f 100644 --- a/noir-projects/noir-protocol-circuits/crates/types/src/grumpkin_point.nr +++ b/noir-projects/noir-protocol-circuits/crates/types/src/grumpkin_point.nr @@ -3,6 +3,7 @@ use dep::std::cmp::Eq; global GRUMPKIN_POINT_SERIALIZED_LEN: Field = 2; +// TODO(https://github.com/noir-lang/noir/issues/4931) struct GrumpkinPoint { x: Field, y: Field, diff --git a/noir/noir-repo/acvm-repo/acir/README.md b/noir/noir-repo/acvm-repo/acir/README.md index 801aeac1140a..e72f7ea178d8 100644 --- a/noir/noir-repo/acvm-repo/acir/README.md +++ b/noir/noir-repo/acvm-repo/acir/README.md @@ -146,6 +146,15 @@ Inputs and outputs are similar to SchnorrVerify, except that because we use a di Because the Grumpkin scalar field is bigger than the ACIR field, we provide 2 ACIR fields representing the low and high parts of the Grumpkin scalar $a$: $a=low+high*2^{128},$ with $low, high < 2^{128}$ +**VariableBaseScalarMul**: scalar multiplication with a variable base/input point (P) of the embedded curve +- input: + point_x, point_y representing x and y coordinates of input point P + scalar_low, scalar_high are 2 (field , 254), representing the low and high part of the input scalar. For Barretenberg, they must both be less than 128 bits. +- output: x and y coordinates of $low*P+high*2^{128}*P$, where P is the input point P + +Because the Grumpkin scalar field is bigger than the ACIR field, we provide 2 ACIR fields representing the low and high parts of the Grumpkin scalar $a$: +$a=low+high*2^{128},$ with $low, high < 2^{128}$ + **Keccak256**: Computes the Keccak-256 (Ethereum version) of the inputs. - inputs: Vector of bytes (FieldElement, 8) - outputs: Vector of 32 bytes (FieldElement, 8) diff --git a/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp b/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp index 0ad193fedf65..1e5207c01cbb 100644 --- a/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp +++ b/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp @@ -145,6 +145,18 @@ namespace Program { static FixedBaseScalarMul bincodeDeserialize(std::vector); }; + struct VariableBaseScalarMul { + Program::FunctionInput point_x; + Program::FunctionInput point_y; + Program::FunctionInput scalar_low; + Program::FunctionInput scalar_high; + std::array outputs; + + friend bool operator==(const VariableBaseScalarMul&, const VariableBaseScalarMul&); + std::vector bincodeSerialize() const; + static VariableBaseScalarMul bincodeDeserialize(std::vector); + }; + struct EmbeddedCurveAdd { Program::FunctionInput input1_x; Program::FunctionInput input1_y; @@ -266,7 +278,7 @@ namespace Program { static Sha256Compression bincodeDeserialize(std::vector); }; - std::variant value; + std::variant value; friend bool operator==(const BlackBoxFuncCall&, const BlackBoxFuncCall&); std::vector bincodeSerialize() const; @@ -729,6 +741,18 @@ namespace Program { static FixedBaseScalarMul bincodeDeserialize(std::vector); }; + struct VariableBaseScalarMul { + Program::MemoryAddress point_x; + Program::MemoryAddress point_y; + Program::MemoryAddress scalar_low; + Program::MemoryAddress scalar_high; + Program::HeapArray result; + + friend bool operator==(const VariableBaseScalarMul&, const VariableBaseScalarMul&); + std::vector bincodeSerialize() const; + static VariableBaseScalarMul bincodeDeserialize(std::vector); + }; + struct EmbeddedCurveAdd { Program::MemoryAddress input1_x; Program::MemoryAddress input1_y; @@ -820,7 +844,7 @@ namespace Program { static Sha256Compression bincodeDeserialize(std::vector); }; - std::variant value; + std::variant value; friend bool operator==(const BlackBoxOp&, const BlackBoxOp&); std::vector bincodeSerialize() const; @@ -2690,6 +2714,56 @@ Program::BlackBoxFuncCall::FixedBaseScalarMul serde::Deserializable BlackBoxFuncCall::VariableBaseScalarMul::bincodeSerialize() const { + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); + } + + inline BlackBoxFuncCall::VariableBaseScalarMul BlackBoxFuncCall::VariableBaseScalarMul::bincodeDeserialize(std::vector input) { + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw serde::deserialization_error("Some input bytes were not read"); + } + return value; + } + +} // end of namespace Program + +template <> +template +void serde::Serializable::serialize(const Program::BlackBoxFuncCall::VariableBaseScalarMul &obj, Serializer &serializer) { + serde::Serializable::serialize(obj.point_x, serializer); + serde::Serializable::serialize(obj.point_y, serializer); + serde::Serializable::serialize(obj.scalar_low, serializer); + serde::Serializable::serialize(obj.scalar_high, serializer); + serde::Serializable::serialize(obj.outputs, serializer); +} + +template <> +template +Program::BlackBoxFuncCall::VariableBaseScalarMul serde::Deserializable::deserialize(Deserializer &deserializer) { + Program::BlackBoxFuncCall::VariableBaseScalarMul obj; + obj.point_x = serde::Deserializable::deserialize(deserializer); + obj.point_y = serde::Deserializable::deserialize(deserializer); + obj.scalar_low = serde::Deserializable::deserialize(deserializer); + obj.scalar_high = serde::Deserializable::deserialize(deserializer); + obj.outputs = serde::Deserializable::deserialize(deserializer); + return obj; +} + namespace Program { inline bool operator==(const BlackBoxFuncCall::EmbeddedCurveAdd &lhs, const BlackBoxFuncCall::EmbeddedCurveAdd &rhs) { @@ -3750,6 +3824,56 @@ Program::BlackBoxOp::FixedBaseScalarMul serde::Deserializable BlackBoxOp::VariableBaseScalarMul::bincodeSerialize() const { + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); + } + + inline BlackBoxOp::VariableBaseScalarMul BlackBoxOp::VariableBaseScalarMul::bincodeDeserialize(std::vector input) { + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw serde::deserialization_error("Some input bytes were not read"); + } + return value; + } + +} // end of namespace Program + +template <> +template +void serde::Serializable::serialize(const Program::BlackBoxOp::VariableBaseScalarMul &obj, Serializer &serializer) { + serde::Serializable::serialize(obj.point_x, serializer); + serde::Serializable::serialize(obj.point_y, serializer); + serde::Serializable::serialize(obj.scalar_low, serializer); + serde::Serializable::serialize(obj.scalar_high, serializer); + serde::Serializable::serialize(obj.result, serializer); +} + +template <> +template +Program::BlackBoxOp::VariableBaseScalarMul serde::Deserializable::deserialize(Deserializer &deserializer) { + Program::BlackBoxOp::VariableBaseScalarMul obj; + obj.point_x = serde::Deserializable::deserialize(deserializer); + obj.point_y = serde::Deserializable::deserialize(deserializer); + obj.scalar_low = serde::Deserializable::deserialize(deserializer); + obj.scalar_high = serde::Deserializable::deserialize(deserializer); + obj.result = serde::Deserializable::deserialize(deserializer); + return obj; +} + namespace Program { inline bool operator==(const BlackBoxOp::EmbeddedCurveAdd &lhs, const BlackBoxOp::EmbeddedCurveAdd &rhs) { diff --git a/noir/noir-repo/acvm-repo/acir/src/circuit/black_box_functions.rs b/noir/noir-repo/acvm-repo/acir/src/circuit/black_box_functions.rs index 0a7ee244a5ee..9a43702a408c 100644 --- a/noir/noir-repo/acvm-repo/acir/src/circuit/black_box_functions.rs +++ b/noir/noir-repo/acvm-repo/acir/src/circuit/black_box_functions.rs @@ -36,8 +36,10 @@ pub enum BlackBoxFunc { EcdsaSecp256k1, /// Verifies a ECDSA signature over the secp256r1 curve. EcdsaSecp256r1, - /// Performs scalar multiplication over the embedded curve on which [`FieldElement`][acir_field::FieldElement] is defined. + /// Performs scalar multiplication over the embedded curve on which [`FieldElement`][acir_field::FieldElement] is defined and a fixed base/generator point G1. FixedBaseScalarMul, + /// Performs scalar multiplication over the embedded curve on which [`FieldElement`][acir_field::FieldElement] is defined and a variable base/input point P. + VariableBaseScalarMul, /// Calculates the Keccak256 hash of the inputs. Keccak256, /// Keccak Permutation function of 1600 width @@ -82,6 +84,7 @@ impl BlackBoxFunc { BlackBoxFunc::PedersenHash => "pedersen_hash", BlackBoxFunc::EcdsaSecp256k1 => "ecdsa_secp256k1", BlackBoxFunc::FixedBaseScalarMul => "fixed_base_scalar_mul", + BlackBoxFunc::VariableBaseScalarMul => "variable_base_scalar_mul", BlackBoxFunc::EmbeddedCurveAdd => "embedded_curve_add", BlackBoxFunc::AND => "and", BlackBoxFunc::XOR => "xor", @@ -112,6 +115,7 @@ impl BlackBoxFunc { "ecdsa_secp256k1" => Some(BlackBoxFunc::EcdsaSecp256k1), "ecdsa_secp256r1" => Some(BlackBoxFunc::EcdsaSecp256r1), "fixed_base_scalar_mul" => Some(BlackBoxFunc::FixedBaseScalarMul), + "variable_base_scalar_mul" => Some(BlackBoxFunc::VariableBaseScalarMul), "embedded_curve_add" => Some(BlackBoxFunc::EmbeddedCurveAdd), "and" => Some(BlackBoxFunc::AND), "xor" => Some(BlackBoxFunc::XOR), diff --git a/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs b/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs index 405cd0cef007..5715019937c2 100644 --- a/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs +++ b/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs @@ -85,6 +85,13 @@ pub enum BlackBoxFuncCall { high: FunctionInput, outputs: (Witness, Witness), }, + VariableBaseScalarMul { + point_x: FunctionInput, + point_y: FunctionInput, + scalar_low: FunctionInput, + scalar_high: FunctionInput, + outputs: (Witness, Witness), + }, EmbeddedCurveAdd { input1_x: FunctionInput, input1_y: FunctionInput, @@ -189,6 +196,7 @@ impl BlackBoxFuncCall { BlackBoxFuncCall::EcdsaSecp256k1 { .. } => BlackBoxFunc::EcdsaSecp256k1, BlackBoxFuncCall::EcdsaSecp256r1 { .. } => BlackBoxFunc::EcdsaSecp256r1, BlackBoxFuncCall::FixedBaseScalarMul { .. } => BlackBoxFunc::FixedBaseScalarMul, + BlackBoxFuncCall::VariableBaseScalarMul { .. } => BlackBoxFunc::VariableBaseScalarMul, BlackBoxFuncCall::EmbeddedCurveAdd { .. } => BlackBoxFunc::EmbeddedCurveAdd, BlackBoxFuncCall::Keccak256 { .. } => BlackBoxFunc::Keccak256, BlackBoxFuncCall::Keccakf1600 { .. } => BlackBoxFunc::Keccakf1600, @@ -232,6 +240,15 @@ impl BlackBoxFuncCall { | BlackBoxFuncCall::BigIntDiv { .. } | BlackBoxFuncCall::BigIntToLeBytes { .. } => Vec::new(), BlackBoxFuncCall::FixedBaseScalarMul { low, high, .. } => vec![*low, *high], + BlackBoxFuncCall::VariableBaseScalarMul { + point_x, + point_y, + scalar_low, + scalar_high, + .. + } => { + vec![*point_x, *point_y, *scalar_low, *scalar_high] + } BlackBoxFuncCall::EmbeddedCurveAdd { input1_x, input1_y, input2_x, input2_y, .. } => vec![*input1_x, *input1_y, *input2_x, *input2_y], @@ -329,6 +346,7 @@ impl BlackBoxFuncCall { | BlackBoxFuncCall::PedersenHash { output, .. } | BlackBoxFuncCall::EcdsaSecp256r1 { output, .. } => vec![*output], BlackBoxFuncCall::FixedBaseScalarMul { outputs, .. } + | BlackBoxFuncCall::VariableBaseScalarMul { outputs, .. } | BlackBoxFuncCall::PedersenCommitment { outputs, .. } | BlackBoxFuncCall::EmbeddedCurveAdd { outputs, .. } => vec![outputs.0, outputs.1], BlackBoxFuncCall::RANGE { .. } diff --git a/noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs b/noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs index c5912b61cf15..2ad082410a1b 100644 --- a/noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs +++ b/noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs @@ -85,6 +85,38 @@ fn fixed_base_scalar_mul_circuit() { assert_eq!(bytes, expected_serialization) } +#[test] +fn variable_base_scalar_mul_circuit() { + let variable_base_scalar_mul = + Opcode::BlackBoxFuncCall(BlackBoxFuncCall::VariableBaseScalarMul { + point_x: FunctionInput { witness: Witness(1), num_bits: 128 }, + point_y: FunctionInput { witness: Witness(2), num_bits: 128 }, + scalar_low: FunctionInput { witness: Witness(3), num_bits: 128 }, + scalar_high: FunctionInput { witness: Witness(4), num_bits: 128 }, + outputs: (Witness(5), Witness(6)), + }); + + let circuit = Circuit { + current_witness_index: 7, + opcodes: vec![variable_base_scalar_mul], + private_parameters: BTreeSet::from([Witness(1), Witness(2), Witness(3), Witness(4)]), + return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(5), Witness(6)])), + ..Circuit::default() + }; + let program = Program { functions: vec![circuit], unconstrained_functions: vec![] }; + + let bytes = Program::serialize_program(&program); + + let expected_serialization: Vec = vec![ + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 93, 139, 65, 10, 0, 32, 8, 4, 213, 172, 46, 61, 186, + 167, 103, 52, 65, 185, 176, 140, 44, 142, 202, 73, 143, 42, 247, 230, 128, 51, 106, 176, + 64, 135, 53, 218, 112, 252, 113, 141, 223, 187, 9, 155, 36, 231, 203, 2, 176, 218, 19, 62, + 137, 0, 0, 0, + ]; + + assert_eq!(bytes, expected_serialization) +} + #[test] fn pedersen_circuit() { let pedersen = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::PedersenCommitment { diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/fixed_base_scalar_mul.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/fixed_base_scalar_mul.rs index c5bfd1d5646d..79e33ae8de53 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/fixed_base_scalar_mul.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/fixed_base_scalar_mul.rs @@ -1,3 +1,4 @@ +// TODO(https://github.com/noir-lang/noir/issues/4932): rename this file to something more generic use acir::{ circuit::opcodes::FunctionInput, native_types::{Witness, WitnessMap}, @@ -24,6 +25,29 @@ pub(super) fn fixed_base_scalar_mul( Ok(()) } +pub(super) fn variable_base_scalar_mul( + backend: &impl BlackBoxFunctionSolver, + initial_witness: &mut WitnessMap, + point_x: FunctionInput, + point_y: FunctionInput, + scalar_low: FunctionInput, + scalar_high: FunctionInput, + outputs: (Witness, Witness), +) -> Result<(), OpcodeResolutionError> { + let point_x = witness_to_value(initial_witness, point_x.witness)?; + let point_y = witness_to_value(initial_witness, point_y.witness)?; + let scalar_low = witness_to_value(initial_witness, scalar_low.witness)?; + let scalar_high = witness_to_value(initial_witness, scalar_high.witness)?; + + let (out_point_x, out_point_y) = + backend.variable_base_scalar_mul(point_x, point_y, scalar_low, scalar_high)?; + + insert_value(&outputs.0, out_point_x, initial_witness)?; + insert_value(&outputs.1, out_point_y, initial_witness)?; + + Ok(()) +} + pub(super) fn embedded_curve_add( backend: &impl BlackBoxFunctionSolver, initial_witness: &mut WitnessMap, diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/mod.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/mod.rs index 2753c7baaaa9..2487d511b502 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/mod.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/mod.rs @@ -20,7 +20,7 @@ mod pedersen; mod range; mod signature; -use fixed_base_scalar_mul::{embedded_curve_add, fixed_base_scalar_mul}; +use fixed_base_scalar_mul::{embedded_curve_add, fixed_base_scalar_mul, variable_base_scalar_mul}; // Hash functions should eventually be exposed for external consumers. use hash::{solve_generic_256_hash_opcode, solve_sha_256_permutation_opcode}; use logic::{and, xor}; @@ -158,6 +158,21 @@ pub(crate) fn solve( BlackBoxFuncCall::FixedBaseScalarMul { low, high, outputs } => { fixed_base_scalar_mul(backend, initial_witness, *low, *high, *outputs) } + BlackBoxFuncCall::VariableBaseScalarMul { + point_x, + point_y, + scalar_low, + scalar_high, + outputs, + } => variable_base_scalar_mul( + backend, + initial_witness, + *point_x, + *point_y, + *scalar_low, + *scalar_high, + *outputs, + ), BlackBoxFuncCall::EmbeddedCurveAdd { input1_x, input1_y, input2_x, input2_y, outputs } => { embedded_curve_add( backend, diff --git a/noir/noir-repo/acvm-repo/acvm_js/test/browser/execute_circuit.test.ts b/noir/noir-repo/acvm-repo/acvm_js/test/browser/execute_circuit.test.ts index 259c51ed1c62..f6287c2ae8a1 100644 --- a/noir/noir-repo/acvm-repo/acvm_js/test/browser/execute_circuit.test.ts +++ b/noir/noir-repo/acvm-repo/acvm_js/test/browser/execute_circuit.test.ts @@ -103,6 +103,16 @@ it('successfully executes a FixedBaseScalarMul opcode', async () => { expect(solvedWitness).to.be.deep.eq(expectedWitnessMap); }); +it('successfully executes a VariableBaseScalarMul opcode', async () => { + const { bytecode, initialWitnessMap, expectedWitnessMap } = await import('../shared/variable_base_scalar_mul'); + + const solvedWitness: WitnessMap = await executeCircuit(bytecode, initialWitnessMap, () => { + throw Error('unexpected oracle'); + }); + + expect(solvedWitness).to.be.deep.eq(expectedWitnessMap); +}); + it('successfully executes a SchnorrVerify opcode', async () => { const { bytecode, initialWitnessMap, expectedWitnessMap } = await import('../shared/schnorr_verify'); diff --git a/noir/noir-repo/acvm-repo/acvm_js/test/node/execute_circuit.test.ts b/noir/noir-repo/acvm-repo/acvm_js/test/node/execute_circuit.test.ts index 32487f8bbbac..f9fd5c10b3ee 100644 --- a/noir/noir-repo/acvm-repo/acvm_js/test/node/execute_circuit.test.ts +++ b/noir/noir-repo/acvm-repo/acvm_js/test/node/execute_circuit.test.ts @@ -100,6 +100,16 @@ it('successfully executes a FixedBaseScalarMul opcode', async () => { expect(solvedWitness).to.be.deep.eq(expectedWitnessMap); }); +it('successfully executes a VariableBaseScalarMul opcode', async () => { + const { bytecode, initialWitnessMap, expectedWitnessMap } = await import('../shared/variable_base_scalar_mul'); + + const solvedWitness: WitnessMap = await executeCircuit(bytecode, initialWitnessMap, () => { + throw Error('unexpected oracle'); + }); + + expect(solvedWitness).to.be.deep.eq(expectedWitnessMap); +}); + it('successfully executes a SchnorrVerify opcode', async () => { const { bytecode, initialWitnessMap, expectedWitnessMap } = await import('../shared/schnorr_verify'); diff --git a/noir/noir-repo/acvm-repo/acvm_js/test/shared/variable_base_scalar_mul.ts b/noir/noir-repo/acvm-repo/acvm_js/test/shared/variable_base_scalar_mul.ts new file mode 100644 index 000000000000..400f7bf4e614 --- /dev/null +++ b/noir/noir-repo/acvm-repo/acvm_js/test/shared/variable_base_scalar_mul.ts @@ -0,0 +1,21 @@ +// See `variable_base_scalar_mul_circuit` integration test in `acir/tests/test_program_serialization.rs`. +export const bytecode = Uint8Array.from([ + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 93, 139, 65, 10, 0, 32, 8, 4, 213, 172, 46, 61, 186, 167, 103, 52, 65, 185, 176, + 140, 44, 142, 202, 73, 143, 42, 247, 230, 128, 51, 106, 176, 64, 135, 53, 218, 112, 252, 113, 141, 223, 187, 9, 155, + 36, 231, 203, 2, 176, 218, 19, 62, 137, 0, 0, 0, +]); +export const initialWitnessMap = new Map([ + [1, '0x0000000000000000000000000000000000000000000000000000000000000001'], + [2, '0x0000000000000002cf135e7506a45d632d270d45f1181294833fc48d823f272c'], + [3, '0x0000000000000000000000000000000000000000000000000000000000000001'], + [4, '0x0000000000000000000000000000000000000000000000000000000000000000'], +]); + +export const expectedWitnessMap = new Map([ + [1, '0x0000000000000000000000000000000000000000000000000000000000000001'], + [2, '0x0000000000000002cf135e7506a45d632d270d45f1181294833fc48d823f272c'], + [3, '0x0000000000000000000000000000000000000000000000000000000000000001'], + [4, '0x0000000000000000000000000000000000000000000000000000000000000000'], + [5, '0x0000000000000000000000000000000000000000000000000000000000000001'], + [6, '0x0000000000000002cf135e7506a45d632d270d45f1181294833fc48d823f272c'], +]); diff --git a/noir/noir-repo/acvm-repo/blackbox_solver/src/curve_specific_solver.rs b/noir/noir-repo/acvm-repo/blackbox_solver/src/curve_specific_solver.rs index fab67467d9ab..a809e21e2ca9 100644 --- a/noir/noir-repo/acvm-repo/blackbox_solver/src/curve_specific_solver.rs +++ b/noir/noir-repo/acvm-repo/blackbox_solver/src/curve_specific_solver.rs @@ -29,6 +29,13 @@ pub trait BlackBoxFunctionSolver { low: &FieldElement, high: &FieldElement, ) -> Result<(FieldElement, FieldElement), BlackBoxResolutionError>; + fn variable_base_scalar_mul( + &self, + point_x: &FieldElement, + point_y: &FieldElement, + scalar_low: &FieldElement, + scalar_high: &FieldElement, + ) -> Result<(FieldElement, FieldElement), BlackBoxResolutionError>; fn ec_add( &self, input1_x: &FieldElement, @@ -85,6 +92,15 @@ impl BlackBoxFunctionSolver for StubbedBlackBoxSolver { ) -> Result<(FieldElement, FieldElement), BlackBoxResolutionError> { Err(Self::fail(BlackBoxFunc::FixedBaseScalarMul)) } + fn variable_base_scalar_mul( + &self, + _point_x: &FieldElement, + _point_y: &FieldElement, + _scalar_low: &FieldElement, + _scalar_high: &FieldElement, + ) -> Result<(FieldElement, FieldElement), BlackBoxResolutionError> { + Err(Self::fail(BlackBoxFunc::VariableBaseScalarMul)) + } fn ec_add( &self, _input1_x: &FieldElement, diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/fixed_base_scalar_mul.rs b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/fixed_base_scalar_mul.rs index cd91c290f494..2d7ffe1cf1c8 100644 --- a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/fixed_base_scalar_mul.rs +++ b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/fixed_base_scalar_mul.rs @@ -1,3 +1,4 @@ +// TODO(https://github.com/noir-lang/noir/issues/4932): rename this file to something more generic use ark_ec::AffineRepr; use ark_ff::MontConfig; use num_bigint::BigUint; @@ -6,40 +7,59 @@ use acir::{BlackBoxFunc, FieldElement}; use crate::BlackBoxResolutionError; +/// Performs fixed-base scalar multiplication using the curve's generator point. pub fn fixed_base_scalar_mul( low: &FieldElement, high: &FieldElement, ) -> Result<(FieldElement, FieldElement), BlackBoxResolutionError> { - let low: u128 = low.try_into_u128().ok_or_else(|| { + let generator = grumpkin::SWAffine::generator(); + let generator_x = FieldElement::from_repr(*generator.x().unwrap()); + let generator_y = FieldElement::from_repr(*generator.y().unwrap()); + + variable_base_scalar_mul(&generator_x, &generator_y, low, high).map_err(|err| match err { + BlackBoxResolutionError::Failed(_, message) => { + BlackBoxResolutionError::Failed(BlackBoxFunc::FixedBaseScalarMul, message) + } + }) +} + +pub fn variable_base_scalar_mul( + point_x: &FieldElement, + point_y: &FieldElement, + scalar_low: &FieldElement, + scalar_high: &FieldElement, +) -> Result<(FieldElement, FieldElement), BlackBoxResolutionError> { + let point1 = create_point(*point_x, *point_y) + .map_err(|e| BlackBoxResolutionError::Failed(BlackBoxFunc::VariableBaseScalarMul, e))?; + + let scalar_low: u128 = scalar_low.try_into_u128().ok_or_else(|| { BlackBoxResolutionError::Failed( - BlackBoxFunc::FixedBaseScalarMul, - format!("Limb {} is not less than 2^128", low.to_hex()), + BlackBoxFunc::VariableBaseScalarMul, + format!("Limb {} is not less than 2^128", scalar_low.to_hex()), ) })?; - let high: u128 = high.try_into_u128().ok_or_else(|| { + let scalar_high: u128 = scalar_high.try_into_u128().ok_or_else(|| { BlackBoxResolutionError::Failed( - BlackBoxFunc::FixedBaseScalarMul, - format!("Limb {} is not less than 2^128", high.to_hex()), + BlackBoxFunc::VariableBaseScalarMul, + format!("Limb {} is not less than 2^128", scalar_high.to_hex()), ) })?; - let mut bytes = high.to_be_bytes().to_vec(); - bytes.extend_from_slice(&low.to_be_bytes()); + let mut bytes = scalar_high.to_be_bytes().to_vec(); + bytes.extend_from_slice(&scalar_low.to_be_bytes()); // Check if this is smaller than the grumpkin modulus let grumpkin_integer = BigUint::from_bytes_be(&bytes); if grumpkin_integer >= grumpkin::FrConfig::MODULUS.into() { return Err(BlackBoxResolutionError::Failed( - BlackBoxFunc::FixedBaseScalarMul, + BlackBoxFunc::VariableBaseScalarMul, format!("{} is not a valid grumpkin scalar", grumpkin_integer.to_str_radix(16)), )); } - let result = grumpkin::SWAffine::from( - grumpkin::SWAffine::generator().mul_bigint(grumpkin_integer.to_u64_digits()), - ); + let result = grumpkin::SWAffine::from(point1.mul_bigint(grumpkin_integer.to_u64_digits())); if let Some((res_x, res_y)) = result.xy() { Ok((FieldElement::from_repr(*res_x), FieldElement::from_repr(*res_y))) } else { @@ -47,17 +67,6 @@ pub fn fixed_base_scalar_mul( } } -fn create_point(x: FieldElement, y: FieldElement) -> Result { - let point = grumpkin::SWAffine::new_unchecked(x.into_repr(), y.into_repr()); - if !point.is_on_curve() { - return Err(format!("Point ({}, {}) is not on curve", x.to_hex(), y.to_hex())); - }; - if !point.is_in_correct_subgroup_assuming_on_curve() { - return Err(format!("Point ({}, {}) is not in correct subgroup", x.to_hex(), y.to_hex())); - }; - Ok(point) -} - pub fn embedded_curve_add( input1_x: FieldElement, input1_y: FieldElement, @@ -79,6 +88,17 @@ pub fn embedded_curve_add( } } +fn create_point(x: FieldElement, y: FieldElement) -> Result { + let point = grumpkin::SWAffine::new_unchecked(x.into_repr(), y.into_repr()); + if !point.is_on_curve() { + return Err(format!("Point ({}, {}) is not on curve", x.to_hex(), y.to_hex())); + }; + if !point.is_in_correct_subgroup_assuming_on_curve() { + return Err(format!("Point ({}, {}) is not in correct subgroup", x.to_hex(), y.to_hex())); + }; + Ok(point) +} + #[cfg(test)] mod grumpkin_fixed_base_scalar_mul { use ark_ff::BigInteger; @@ -147,6 +167,46 @@ mod grumpkin_fixed_base_scalar_mul { ); } + #[test] + fn variable_base_matches_fixed_base_for_generator_on_input( + ) -> Result<(), BlackBoxResolutionError> { + let low = FieldElement::one(); + let high = FieldElement::from(2u128); + + let generator = grumpkin::SWAffine::generator(); + let generator_x = FieldElement::from_repr(*generator.x().unwrap()); + let generator_y = FieldElement::from_repr(*generator.y().unwrap()); + + let fixed_res = fixed_base_scalar_mul(&low, &high)?; + let variable_res = variable_base_scalar_mul(&generator_x, &generator_y, &low, &high)?; + + assert_eq!(fixed_res, variable_res); + Ok(()) + } + + #[test] + fn variable_base_scalar_mul_rejects_invalid_point() { + let invalid_point_x = FieldElement::one(); + let invalid_point_y = FieldElement::one(); + let valid_scalar_low = FieldElement::zero(); + let valid_scalar_high = FieldElement::zero(); + + let res = variable_base_scalar_mul( + &invalid_point_x, + &invalid_point_y, + &valid_scalar_low, + &valid_scalar_high, + ); + + assert_eq!( + res, + Err(BlackBoxResolutionError::Failed( + BlackBoxFunc::VariableBaseScalarMul, + "Point (0000000000000000000000000000000000000000000000000000000000000001, 0000000000000000000000000000000000000000000000000000000000000001) is not on curve".into(), + )) + ); + } + #[test] fn rejects_addition_of_points_not_in_curve() { let x = FieldElement::from(1u128); diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/lib.rs b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/lib.rs index 25b10252a784..9395260fe36a 100644 --- a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/lib.rs +++ b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/lib.rs @@ -9,7 +9,9 @@ mod fixed_base_scalar_mul; mod poseidon2; mod wasm; -pub use fixed_base_scalar_mul::{embedded_curve_add, fixed_base_scalar_mul}; +pub use fixed_base_scalar_mul::{ + embedded_curve_add, fixed_base_scalar_mul, variable_base_scalar_mul, +}; pub use poseidon2::poseidon2_permutation; use wasm::Barretenberg; @@ -97,6 +99,16 @@ impl BlackBoxFunctionSolver for Bn254BlackBoxSolver { fixed_base_scalar_mul(low, high) } + fn variable_base_scalar_mul( + &self, + point_x: &FieldElement, + point_y: &FieldElement, + low: &FieldElement, + high: &FieldElement, + ) -> Result<(FieldElement, FieldElement), BlackBoxResolutionError> { + variable_base_scalar_mul(point_x, point_y, low, high) + } + fn ec_add( &self, input1_x: &FieldElement, diff --git a/noir/noir-repo/acvm-repo/brillig/src/black_box.rs b/noir/noir-repo/acvm-repo/brillig/src/black_box.rs index 29861d0fd841..f31a434c7725 100644 --- a/noir/noir-repo/acvm-repo/brillig/src/black_box.rs +++ b/noir/noir-repo/acvm-repo/brillig/src/black_box.rs @@ -72,6 +72,14 @@ pub enum BlackBoxOp { high: MemoryAddress, result: HeapArray, }, + /// Performs scalar multiplication over the embedded curve with variable base point. + VariableBaseScalarMul { + point_x: MemoryAddress, + point_y: MemoryAddress, + scalar_low: MemoryAddress, + scalar_high: MemoryAddress, + result: HeapArray, + }, /// Performs addition over the embedded curve. EmbeddedCurveAdd { input1_x: MemoryAddress, diff --git a/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs b/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs index 19407da52dbe..9557cdae7b91 100644 --- a/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs +++ b/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs @@ -143,6 +143,19 @@ pub(crate) fn evaluate_black_box( memory.write_slice(memory.read_ref(result.pointer), &[x.into(), y.into()]); Ok(()) } + BlackBoxOp::VariableBaseScalarMul { point_x, point_y, scalar_low, scalar_high, result } => { + let point_x = memory.read(*point_x).try_into().unwrap(); + let point_y = memory.read(*point_y).try_into().unwrap(); + let scalar_low = memory.read(*scalar_low).try_into().unwrap(); + let scalar_high = memory.read(*scalar_high).try_into().unwrap(); + let (out_point_x, out_point_y) = + solver.variable_base_scalar_mul(&point_x, &point_y, &scalar_low, &scalar_high)?; + memory.write_slice( + memory.read_ref(result.pointer), + &[out_point_x.into(), out_point_y.into()], + ); + Ok(()) + } BlackBoxOp::EmbeddedCurveAdd { input1_x, input1_y, input2_x, input2_y, result } => { let input1_x = memory.read(*input1_x).try_into().unwrap(); let input1_y = memory.read(*input1_y).try_into().unwrap(); @@ -289,6 +302,7 @@ fn black_box_function_from_op(op: &BlackBoxOp) -> BlackBoxFunc { BlackBoxOp::PedersenCommitment { .. } => BlackBoxFunc::PedersenCommitment, BlackBoxOp::PedersenHash { .. } => BlackBoxFunc::PedersenHash, BlackBoxOp::FixedBaseScalarMul { .. } => BlackBoxFunc::FixedBaseScalarMul, + BlackBoxOp::VariableBaseScalarMul { .. } => BlackBoxFunc::VariableBaseScalarMul, BlackBoxOp::EmbeddedCurveAdd { .. } => BlackBoxFunc::EmbeddedCurveAdd, BlackBoxOp::BigIntAdd { .. } => BlackBoxFunc::BigIntAdd, BlackBoxOp::BigIntSub { .. } => BlackBoxFunc::BigIntSub, diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs index ee047903743d..210e56b2ecba 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs @@ -201,7 +201,26 @@ pub(crate) fn convert_black_box_call( }); } else { unreachable!( - "ICE: FixedBaseScalarMul expects one register argument and one array result" + "ICE: FixedBaseScalarMul expects two register arguments and one array result" + ) + } + } + BlackBoxFunc::VariableBaseScalarMul => { + if let ( + [BrilligVariable::SingleAddr(point_x), BrilligVariable::SingleAddr(point_y), BrilligVariable::SingleAddr(scalar_low), BrilligVariable::SingleAddr(scalar_high)], + [BrilligVariable::BrilligArray(result_array)], + ) = (function_arguments, function_results) + { + brillig_context.black_box_op_instruction(BlackBoxOp::VariableBaseScalarMul { + point_x: point_x.address, + point_y: point_y.address, + scalar_low: scalar_low.address, + scalar_high: scalar_high.address, + result: result_array.to_heap_array(), + }); + } else { + unreachable!( + "ICE: VariableBaseScalarMul expects four register arguments and one array result" ) } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs index ded41e02bdae..b4ed59de59d9 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs @@ -175,6 +175,16 @@ pub(crate) mod tests { Ok((4_u128.into(), 5_u128.into())) } + fn variable_base_scalar_mul( + &self, + _point_x: &FieldElement, + _point_y: &FieldElement, + _scalar_low: &FieldElement, + _scalar_high: &FieldElement, + ) -> Result<(FieldElement, FieldElement), BlackBoxResolutionError> { + Ok((7_u128.into(), 8_u128.into())) + } + fn ec_add( &self, _input1_x: &FieldElement, diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs index 5601bbde8772..8b00939b3a71 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs @@ -324,6 +324,23 @@ impl DebugShow { result ); } + BlackBoxOp::VariableBaseScalarMul { + point_x, + point_y, + scalar_low, + scalar_high, + result, + } => { + debug_println!( + self.enable_debug_trace, + " VARIABLE_BASE_SCALAR_MUL ({} {}) ({} {}) -> {}", + point_x, + point_y, + scalar_low, + scalar_high, + result + ); + } BlackBoxOp::EmbeddedCurveAdd { input1_x, input1_y, input2_x, input2_y, result } => { debug_println!( self.enable_debug_trace, diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs index c084ba37fee6..2f4f4f9f6cc7 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs @@ -283,6 +283,13 @@ impl GeneratedAcir { high: inputs[1][0], outputs: (outputs[0], outputs[1]), }, + BlackBoxFunc::VariableBaseScalarMul => BlackBoxFuncCall::VariableBaseScalarMul { + point_x: inputs[0][0], + point_y: inputs[1][0], + scalar_low: inputs[2][0], + scalar_high: inputs[3][0], + outputs: (outputs[0], outputs[1]), + }, BlackBoxFunc::EmbeddedCurveAdd => BlackBoxFuncCall::EmbeddedCurveAdd { input1_x: inputs[0][0], input1_y: inputs[1][0], @@ -669,6 +676,10 @@ fn black_box_func_expected_input_size(name: BlackBoxFunc) -> Option { // is the low and high limbs of the scalar BlackBoxFunc::FixedBaseScalarMul => Some(2), + // Inputs for variable based scalar multiplication are the x and y coordinates of the base point and low + // and high limbs of the scalar + BlackBoxFunc::VariableBaseScalarMul => Some(4), + // Recursive aggregation has a variable number of inputs BlackBoxFunc::RecursiveAggregation => None, @@ -723,7 +734,9 @@ fn black_box_expected_output_size(name: BlackBoxFunc) -> Option { // Output of operations over the embedded curve // will be 2 field elements representing the point. - BlackBoxFunc::FixedBaseScalarMul | BlackBoxFunc::EmbeddedCurveAdd => Some(2), + BlackBoxFunc::FixedBaseScalarMul + | BlackBoxFunc::VariableBaseScalarMul + | BlackBoxFunc::EmbeddedCurveAdd => Some(2), // Big integer operations return a big integer BlackBoxFunc::BigIntAdd diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs index 1187ea8cb07c..a8365ffef39b 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -453,6 +453,7 @@ fn simplify_black_box_func( } BlackBoxFunc::FixedBaseScalarMul + | BlackBoxFunc::VariableBaseScalarMul | BlackBoxFunc::SchnorrVerify | BlackBoxFunc::PedersenCommitment | BlackBoxFunc::PedersenHash diff --git a/noir/noir-repo/docs/docs/noir/standard_library/cryptographic_primitives/scalar.mdx b/noir/noir-repo/docs/docs/noir/standard_library/cryptographic_primitives/scalar.mdx index c2946b2b73b3..b835236a03e4 100644 --- a/noir/noir-repo/docs/docs/noir/standard_library/cryptographic_primitives/scalar.mdx +++ b/noir/noir-repo/docs/docs/noir/standard_library/cryptographic_primitives/scalar.mdx @@ -1,6 +1,6 @@ --- title: Scalar multiplication -description: See how you can perform scalar multiplications over a fixed base in Noir +description: See how you can perform scalar multiplications over a fixed and variable bases in Noir keywords: [cryptographic primitives, Noir project, scalar multiplication] sidebar_position: 1 --- @@ -9,17 +9,35 @@ import BlackBoxInfo from '@site/src/components/Notes/_blackbox.mdx'; ## scalar_mul::fixed_base_embedded_curve -Performs scalar multiplication over the embedded curve whose coordinates are defined by the -configured noir field. For the BN254 scalar field, this is BabyJubJub or Grumpkin. +Performs scalar multiplication of a fixed base/generator over the embedded curve whose coordinates are defined +by the configured noir field. For the BN254 scalar field, this is BabyJubJub or Grumpkin. Suffixes `_low` and +`_high` denote low and high limbs of the input scalar. #include_code fixed_base_embedded_curve noir_stdlib/src/scalar_mul.nr rust example ```rust -fn main(x : Field) { - let scal = std::scalar_mul::fixed_base_embedded_curve(x); - println(scal); +fn main(scalar_low: Field, scalar_high: Field) { + let point = std::scalar_mul::fixed_base_embedded_curve(scalar_low, scalar_high); + println(point); +} +``` + +## scalar_mul::variable_base_embedded_curve + +Performs scalar multiplication of a variable base/input point over the embedded curve whose coordinates are defined +by the configured noir field. For the BN254 scalar field, this is BabyJubJub or Grumpkin. Suffixes `_low` and +`_high` denote low and high limbs of the input scalar. + +#include_code variable_base_embedded_curve noir_stdlib/src/scalar_mul.nr rust + +example + +```rust +fn main(point_x: Field, point_y: Field, scalar_low: Field, scalar_high: Field) { + let resulting_point = std::scalar_mul::fixed_base_embedded_curve(point_x, point_y, scalar_low, scalar_high); + println(resulting_point); } ``` diff --git a/noir/noir-repo/noir_stdlib/src/grumpkin_scalar_mul.nr b/noir/noir-repo/noir_stdlib/src/grumpkin_scalar_mul.nr index 06d30d623321..c1195073ef6d 100644 --- a/noir/noir-repo/noir_stdlib/src/grumpkin_scalar_mul.nr +++ b/noir/noir-repo/noir_stdlib/src/grumpkin_scalar_mul.nr @@ -1,7 +1,6 @@ use crate::grumpkin_scalar::GrumpkinScalar; -use crate::scalar_mul::fixed_base_embedded_curve; +use crate::scalar_mul::{fixed_base_embedded_curve, variable_base_embedded_curve}; pub fn grumpkin_fixed_base(scalar: GrumpkinScalar) -> [Field; 2] { - // TODO: this should use both the low and high limbs to do the scalar multiplication fixed_base_embedded_curve(scalar.low, scalar.high) -} +} \ No newline at end of file diff --git a/noir/noir-repo/noir_stdlib/src/scalar_mul.nr b/noir/noir-repo/noir_stdlib/src/scalar_mul.nr index eee7aac39f2e..457b7b7791c8 100644 --- a/noir/noir-repo/noir_stdlib/src/scalar_mul.nr +++ b/noir/noir-repo/noir_stdlib/src/scalar_mul.nr @@ -1,5 +1,6 @@ use crate::ops::Add; +// TODO(https://github.com/noir-lang/noir/issues/4931) struct EmbeddedCurvePoint { x: Field, y: Field, @@ -26,12 +27,30 @@ impl Add for EmbeddedCurvePoint { #[foreign(fixed_base_scalar_mul)] // docs:start:fixed_base_embedded_curve pub fn fixed_base_embedded_curve( - low: Field, - high: Field + low: Field, // low limb of the scalar + high: Field // high limb of the scalar ) -> [Field; 2] // docs:end:fixed_base_embedded_curve {} +// Computes a variable base scalar multiplication over the embedded curve. +// For bn254, We have Grumpkin and Baby JubJub. +// For bls12-381, we have JubJub and Bandersnatch. +// +// The embedded curve being used is decided by the +// underlying proof system. +// TODO(https://github.com/noir-lang/noir/issues/4931): use a point struct instead of two fields +#[foreign(variable_base_scalar_mul)] +// docs:start:variable_base_embedded_curve +pub fn variable_base_embedded_curve( + point_x: Field, // x coordinate of a point to multiply the scalar with + point_y: Field, // y coordinate of a point to multiply the scalar with + scalar_low: Field, // low limb of the scalar + scalar_high: Field // high limb of the scalar +) -> [Field; 2] +// docs:end:variable_base_embedded_curve +{} + // This is a hack as returning an `EmbeddedCurvePoint` from a foreign function in brillig returns a [BrilligVariable::SingleAddr; 2] rather than BrilligVariable::BrilligArray // as is defined in the brillig bytecode format. This is a workaround which allows us to fix this without modifying the serialization format. fn embedded_curve_add(point1: EmbeddedCurvePoint, point2: EmbeddedCurvePoint) -> EmbeddedCurvePoint { diff --git a/noir/noir-repo/test_programs/execution_success/scalar_mul/Nargo.toml b/noir/noir-repo/test_programs/execution_success/fixed_base_scalar_mul/Nargo.toml similarity index 63% rename from noir/noir-repo/test_programs/execution_success/scalar_mul/Nargo.toml rename to noir/noir-repo/test_programs/execution_success/fixed_base_scalar_mul/Nargo.toml index 926114ec3745..a8e45c9b5ade 100644 --- a/noir/noir-repo/test_programs/execution_success/scalar_mul/Nargo.toml +++ b/noir/noir-repo/test_programs/execution_success/fixed_base_scalar_mul/Nargo.toml @@ -1,5 +1,5 @@ [package] -name = "scalar_mul" +name = "fixed_base_scalar_mul" type = "bin" authors = [""] diff --git a/noir/noir-repo/test_programs/execution_success/scalar_mul/Prover.toml b/noir/noir-repo/test_programs/execution_success/fixed_base_scalar_mul/Prover.toml similarity index 100% rename from noir/noir-repo/test_programs/execution_success/scalar_mul/Prover.toml rename to noir/noir-repo/test_programs/execution_success/fixed_base_scalar_mul/Prover.toml diff --git a/noir/noir-repo/test_programs/execution_success/scalar_mul/src/main.nr b/noir/noir-repo/test_programs/execution_success/fixed_base_scalar_mul/src/main.nr similarity index 100% rename from noir/noir-repo/test_programs/execution_success/scalar_mul/src/main.nr rename to noir/noir-repo/test_programs/execution_success/fixed_base_scalar_mul/src/main.nr diff --git a/noir/noir-repo/test_programs/execution_success/variable_base_scalar_mul/Nargo.toml b/noir/noir-repo/test_programs/execution_success/variable_base_scalar_mul/Nargo.toml new file mode 100644 index 000000000000..66712ab503cb --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/variable_base_scalar_mul/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "variable_base_scalar_mul" +type = "bin" +authors = [""] + +[dependencies] diff --git a/noir/noir-repo/test_programs/execution_success/variable_base_scalar_mul/Prover.toml b/noir/noir-repo/test_programs/execution_success/variable_base_scalar_mul/Prover.toml new file mode 100644 index 000000000000..51d6fc9b96c5 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/variable_base_scalar_mul/Prover.toml @@ -0,0 +1,4 @@ +point_x = "0x0000000000000000000000000000000000000000000000000000000000000001" +point_y = "0x0000000000000002cf135e7506a45d632d270d45f1181294833fc48d823f272c" +scalar_low = "0x0000000000000000000000000000000000000000000000000000000000000003" +scalar_high = "0x0000000000000000000000000000000000000000000000000000000000000000" diff --git a/noir/noir-repo/test_programs/execution_success/variable_base_scalar_mul/src/main.nr b/noir/noir-repo/test_programs/execution_success/variable_base_scalar_mul/src/main.nr new file mode 100644 index 000000000000..4914ad017771 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/variable_base_scalar_mul/src/main.nr @@ -0,0 +1,33 @@ +use dep::std; + +fn main(point_x: pub Field, point_y: pub Field, scalar_low: pub Field, scalar_high: pub Field) { + // We multiply the point by 3 and check it matches result out of embedded_curve_add func + let res = std::scalar_mul::variable_base_embedded_curve(point_x, point_y, scalar_low, scalar_high); + + let point = std::scalar_mul::EmbeddedCurvePoint { x: point_x, y: point_y }; + + let double = point.double(); + let triple = point + double; + + assert(triple.x == res[0]); + assert(triple.y == res[1]); + + // We test that brillig gives us the same result + let brillig_res = get_brillig_result(point_x, point_y, scalar_low, scalar_high); + assert(res[0] == brillig_res[0]); + assert(res[1] == brillig_res[1]); + + // Multiplying the point by 1 should return the same point + let res = std::scalar_mul::variable_base_embedded_curve(point_x, point_y, 1, 0); + assert(point_x == res[0]); + assert(point_y == res[1]); +} + +unconstrained fn get_brillig_result( + point_x: Field, + point_y: Field, + scalar_low: Field, + scalar_high: Field +) -> [Field; 2] { + std::scalar_mul::variable_base_embedded_curve(point_x, point_y, scalar_low, scalar_high) +} diff --git a/noir/noir-repo/tooling/lsp/src/solver.rs b/noir/noir-repo/tooling/lsp/src/solver.rs index 0fea9b16b54a..b47c30af5f68 100644 --- a/noir/noir-repo/tooling/lsp/src/solver.rs +++ b/noir/noir-repo/tooling/lsp/src/solver.rs @@ -32,6 +32,16 @@ impl BlackBoxFunctionSolver for WrapperSolver { self.0.fixed_base_scalar_mul(low, high) } + fn variable_base_scalar_mul( + &self, + point_x: &acvm::FieldElement, + point_y: &acvm::FieldElement, + scalar_low: &acvm::FieldElement, + scalar_high: &acvm::FieldElement, + ) -> Result<(acvm::FieldElement, acvm::FieldElement), acvm::BlackBoxResolutionError> { + self.0.variable_base_scalar_mul(point_x, point_y, scalar_low, scalar_high) + } + fn pedersen_hash( &self, inputs: &[acvm::FieldElement], diff --git a/yarn-project/end-to-end/src/e2e_state_vars.test.ts b/yarn-project/end-to-end/src/e2e_state_vars.test.ts index 8a6ed6dc23ea..8d63cafea4f9 100644 --- a/yarn-project/end-to-end/src/e2e_state_vars.test.ts +++ b/yarn-project/end-to-end/src/e2e_state_vars.test.ts @@ -22,7 +22,7 @@ describe('e2e_state_vars', () => { beforeAll(async () => { ({ teardown, wallet, pxe } = await setup(2)); contract = await DocsExampleContract.deploy(wallet).send().deployed(); - }, 30_000); + }, 60_000); afterAll(() => teardown()); From 1904fa864ff8c546d4d849436c6ca7a7606fb3d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Bene=C5=A1?= Date: Tue, 30 Apr 2024 15:10:43 +0200 Subject: [PATCH 09/42] feat: computing sym key for incoming ciphertext (#6020) --- .../src/core/libraries/ConstantsGen.sol | 14 ++--- noir-projects/aztec-nr/aztec/src/keys.nr | 1 + .../aztec/src/keys/point_to_symmetric_key.nr | 34 +++++++++++++ noir-projects/aztec-nr/aztec/src/lib.nr | 1 + .../crates/types/src/constants.nr | 1 + .../crates/types/src/grumpkin_point.nr | 11 ++++ .../l1_note_payload/encrypt_buffer.test.ts | 48 ++++++++++------- .../logs/l1_note_payload/encrypt_buffer.ts | 51 +++++++++---------- .../l1_note_payload/l1_note_payload.test.ts | 8 +-- .../logs/l1_note_payload/l1_note_payload.ts | 23 +++------ .../logs/l1_note_payload/tagged_note.test.ts | 8 +-- .../src/logs/l1_note_payload/tagged_note.ts | 11 ++-- yarn-project/circuits.js/src/constants.gen.ts | 13 +++-- .../foundation/src/testing/test_data.ts | 2 +- .../src/note_processor/note_processor.test.ts | 4 +- .../pxe/src/note_processor/note_processor.ts | 4 +- .../src/client/client_execution_context.ts | 6 +-- .../simulator/src/client/simulator.ts | 4 -- 18 files changed, 145 insertions(+), 99 deletions(-) create mode 100644 noir-projects/aztec-nr/aztec/src/keys.nr create mode 100644 noir-projects/aztec-nr/aztec/src/keys/point_to_symmetric_key.nr diff --git a/l1-contracts/src/core/libraries/ConstantsGen.sol b/l1-contracts/src/core/libraries/ConstantsGen.sol index b9e22c3eb709..d6a267143885 100644 --- a/l1-contracts/src/core/libraries/ConstantsGen.sol +++ b/l1-contracts/src/core/libraries/ConstantsGen.sol @@ -118,8 +118,10 @@ library Constants { uint256 internal constant NULLIFIER_KEY_VALIDATION_REQUEST_CONTEXT_LENGTH = 4; uint256 internal constant PARTIAL_STATE_REFERENCE_LENGTH = 6; uint256 internal constant READ_REQUEST_LENGTH = 2; + uint256 internal constant NOTE_HASH_LENGTH = 2; + uint256 internal constant NOTE_HASH_CONTEXT_LENGTH = 3; + uint256 internal constant NULLIFIER_LENGTH = 3; uint256 internal constant SIDE_EFFECT_LENGTH = 2; - uint256 internal constant SIDE_EFFECT_LINKED_TO_NOTE_HASH_LENGTH = 3; uint256 internal constant STATE_REFERENCE_LENGTH = APPEND_ONLY_TREE_SNAPSHOT_LENGTH + PARTIAL_STATE_REFERENCE_LENGTH; uint256 internal constant TX_CONTEXT_LENGTH = 2 + GAS_SETTINGS_LENGTH; @@ -130,9 +132,9 @@ library Constants { + MAX_BLOCK_NUMBER_LENGTH + (SIDE_EFFECT_LENGTH * MAX_NOTE_HASH_READ_REQUESTS_PER_CALL) + (READ_REQUEST_LENGTH * MAX_NULLIFIER_READ_REQUESTS_PER_CALL) + (NULLIFIER_KEY_VALIDATION_REQUEST_LENGTH * MAX_NULLIFIER_KEY_VALIDATION_REQUESTS_PER_CALL) - + (SIDE_EFFECT_LENGTH * MAX_NEW_NOTE_HASHES_PER_CALL) - + (SIDE_EFFECT_LINKED_TO_NOTE_HASH_LENGTH * MAX_NEW_NULLIFIERS_PER_CALL) - + MAX_PRIVATE_CALL_STACK_LENGTH_PER_CALL + MAX_PUBLIC_CALL_STACK_LENGTH_PER_CALL + + (NOTE_HASH_LENGTH * MAX_NEW_NOTE_HASHES_PER_CALL) + + (NULLIFIER_LENGTH * MAX_NEW_NULLIFIERS_PER_CALL) + MAX_PRIVATE_CALL_STACK_LENGTH_PER_CALL + + MAX_PUBLIC_CALL_STACK_LENGTH_PER_CALL + (L2_TO_L1_MESSAGE_LENGTH * MAX_NEW_L2_TO_L1_MSGS_PER_CALL) + 2 + (SIDE_EFFECT_LENGTH * MAX_ENCRYPTED_LOGS_PER_CALL) + (SIDE_EFFECT_LENGTH * MAX_UNENCRYPTED_LOGS_PER_CALL) + 2 + HEADER_LENGTH + TX_CONTEXT_LENGTH; @@ -141,8 +143,8 @@ library Constants { + (READ_REQUEST_LENGTH * MAX_NULLIFIER_NON_EXISTENT_READ_REQUESTS_PER_CALL) + (CONTRACT_STORAGE_UPDATE_REQUEST_LENGTH * MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_CALL) + (CONTRACT_STORAGE_READ_LENGTH * MAX_PUBLIC_DATA_READS_PER_CALL) - + MAX_PUBLIC_CALL_STACK_LENGTH_PER_CALL + (SIDE_EFFECT_LENGTH * MAX_NEW_NOTE_HASHES_PER_CALL) - + (SIDE_EFFECT_LINKED_TO_NOTE_HASH_LENGTH * MAX_NEW_NULLIFIERS_PER_CALL) + + MAX_PUBLIC_CALL_STACK_LENGTH_PER_CALL + (NOTE_HASH_LENGTH * MAX_NEW_NOTE_HASHES_PER_CALL) + + (NULLIFIER_LENGTH * MAX_NEW_NULLIFIERS_PER_CALL) + (L2_TO_L1_MESSAGE_LENGTH * MAX_NEW_L2_TO_L1_MSGS_PER_CALL) + 2 + (SIDE_EFFECT_LENGTH * MAX_UNENCRYPTED_LOGS_PER_CALL) + 1 + HEADER_LENGTH + GLOBAL_VARIABLES_LENGTH + AZTEC_ADDRESS_LENGTH /* revert_code */ + 1 + 2 * GAS_LENGTH /* transaction_fee */ diff --git a/noir-projects/aztec-nr/aztec/src/keys.nr b/noir-projects/aztec-nr/aztec/src/keys.nr new file mode 100644 index 000000000000..20d48201aca0 --- /dev/null +++ b/noir-projects/aztec-nr/aztec/src/keys.nr @@ -0,0 +1 @@ +mod point_to_symmetric_key; \ No newline at end of file diff --git a/noir-projects/aztec-nr/aztec/src/keys/point_to_symmetric_key.nr b/noir-projects/aztec-nr/aztec/src/keys/point_to_symmetric_key.nr new file mode 100644 index 000000000000..b708d00e8bc1 --- /dev/null +++ b/noir-projects/aztec-nr/aztec/src/keys/point_to_symmetric_key.nr @@ -0,0 +1,34 @@ +use dep::protocol_types::{constants::GENERATOR_INDEX__SYMMETRIC_KEY, grumpkin_point::GrumpkinPoint, utils::arr_copy_slice}; +use dep::std::{hash::sha256, grumpkin_scalar::GrumpkinScalar, scalar_mul::variable_base_embedded_curve}; + +// TODO(#5726): This function is called deriveAESSecret in TS. I don't like point_to_symmetric_key name much since +// point is not the only input of the function. Unify naming with TS once we have a better name. +pub fn point_to_symmetric_key(secret: GrumpkinScalar, point: GrumpkinPoint) -> [u8; 32] { + let shared_secret_fields = variable_base_embedded_curve(point.x, point.y, secret.low, secret.high); + // TODO(https://github.com/AztecProtocol/aztec-packages/issues/6061): make the func return Point struct directly + let shared_secret = GrumpkinPoint::new(shared_secret_fields[0], shared_secret_fields[1]); + let mut shared_secret_bytes_with_separator = [0 as u8; 65]; + shared_secret_bytes_with_separator = arr_copy_slice(shared_secret.to_be_bytes(), shared_secret_bytes_with_separator, 0); + shared_secret_bytes_with_separator[64] = GENERATOR_INDEX__SYMMETRIC_KEY; + sha256(shared_secret_bytes_with_separator) +} + +#[test] +fn check_point_to_symmetric_key() { + // Value taken from "derive shared secret" test in encrypt_buffer.test.ts + let secret = GrumpkinScalar::new( + 0x00000000000000000000000000000000649e7ca01d9de27b21624098b897babd, + 0x0000000000000000000000000000000023b3127c127b1f29a7adff5cccf8fb06 + ); + let point = GrumpkinPoint::new( + 0x2688431c705a5ff3e6c6f2573c9e3ba1c1026d2251d0dbbf2d810aa53fd1d186, + 0x1e96887b117afca01c00468264f4f80b5bb16d94c1808a448595f115556e5c8e + ); + + let key = point_to_symmetric_key(secret, point); + // The following value gets updated when running encrypt_buffer.test.ts with AZTEC_GENERATE_TEST_DATA=1 + let expected_key = [ + 198, 74, 242, 51, 177, 36, 183, 8, 2, 246, 197, 138, 59, 166, 86, 96, 155, 50, 186, 34, 242, 3, 208, 144, 161, 64, 69, 165, 70, 57, 226, 139 + ]; + assert_eq(key, expected_key); +} diff --git a/noir-projects/aztec-nr/aztec/src/lib.nr b/noir-projects/aztec-nr/aztec/src/lib.nr index 81390fb52eb9..67a3c8c55b47 100644 --- a/noir-projects/aztec-nr/aztec/src/lib.nr +++ b/noir-projects/aztec-nr/aztec/src/lib.nr @@ -3,6 +3,7 @@ mod deploy; mod hash; mod history; mod initializer; +mod keys; mod log; mod messaging; mod note; diff --git a/noir-projects/noir-protocol-circuits/crates/types/src/constants.nr b/noir-projects/noir-protocol-circuits/crates/types/src/constants.nr index 61924a73f8e8..37c7bc84848c 100644 --- a/noir-projects/noir-protocol-circuits/crates/types/src/constants.nr +++ b/noir-projects/noir-protocol-circuits/crates/types/src/constants.nr @@ -255,3 +255,4 @@ global GENERATOR_INDEX__PUBLIC_KEYS_HASH = 51; global GENERATOR_INDEX__NOTE_NULLIFIER = 52; global GENERATOR_INDEX__INNER_NOTE_HASH = 53; global GENERATOR_INDEX__NOTE_CONTENT_HASH = 54; +global GENERATOR_INDEX__SYMMETRIC_KEY: u8 = 55; diff --git a/noir-projects/noir-protocol-circuits/crates/types/src/grumpkin_point.nr b/noir-projects/noir-protocol-circuits/crates/types/src/grumpkin_point.nr index ec334d78b50f..467a022947b3 100644 --- a/noir-projects/noir-protocol-circuits/crates/types/src/grumpkin_point.nr +++ b/noir-projects/noir-protocol-circuits/crates/types/src/grumpkin_point.nr @@ -51,4 +51,15 @@ impl GrumpkinPoint { assert(self.x == 0); assert(self.y == 0); } + + pub fn to_be_bytes(self: Self) -> [u8; 64] { + let mut result = [0 as u8; 64]; + let x_bytes = self.x.to_be_bytes(32); + let y_bytes = self.y.to_be_bytes(32); + for i in 0..32 { + result[i] = x_bytes[i]; + result[i + 32] = y_bytes[i]; + } + result + } } diff --git a/yarn-project/circuit-types/src/logs/l1_note_payload/encrypt_buffer.test.ts b/yarn-project/circuit-types/src/logs/l1_note_payload/encrypt_buffer.test.ts index b34baa4fb01f..d8a0c4b9998f 100644 --- a/yarn-project/circuit-types/src/logs/l1_note_payload/encrypt_buffer.test.ts +++ b/yarn-project/circuit-types/src/logs/l1_note_payload/encrypt_buffer.test.ts @@ -1,6 +1,7 @@ -import { GrumpkinScalar } from '@aztec/circuits.js'; +import { Fq, GrumpkinScalar } from '@aztec/circuits.js'; import { Grumpkin } from '@aztec/circuits.js/barretenberg'; import { randomBytes } from '@aztec/foundation/crypto'; +import { updateInlineTestData } from '@aztec/foundation/testing'; import { decryptBuffer, deriveAESSecret, encryptBuffer } from './encrypt_buffer.js'; @@ -12,38 +13,51 @@ describe('encrypt buffer', () => { }); it('derive shared secret', () => { - const ownerPrivKey = GrumpkinScalar.random(); - const ownerPubKey = grumpkin.mul(Grumpkin.generator, ownerPrivKey); - const ephPrivKey = GrumpkinScalar.random(); - const ephPubKey = grumpkin.mul(Grumpkin.generator, ephPrivKey); + // The following 2 are arbitrary fixed values - fixed in order to test a match with Noir + const ownerSecretKey: GrumpkinScalar = new Fq(0x23b3127c127b1f29a7adff5cccf8fb06649e7ca01d9de27b21624098b897babdn); + const ephSecretKey: GrumpkinScalar = new Fq(0x1fdd0dd8c99b21af8e00d2d130bdc263b36dadcbea84ac5ec9293a0660deca01n); - const secretBySender = deriveAESSecret(ownerPubKey, ephPrivKey, grumpkin); - const secretByReceiver = deriveAESSecret(ephPubKey, ownerPrivKey, grumpkin); + const ownerPubKey = grumpkin.mul(Grumpkin.generator, ownerSecretKey); + const ephPubKey = grumpkin.mul(Grumpkin.generator, ephSecretKey); + + const secretBySender = deriveAESSecret(ephSecretKey, ownerPubKey); + const secretByReceiver = deriveAESSecret(ownerSecretKey, ephPubKey); expect(secretBySender.toString('hex')).toEqual(secretByReceiver.toString('hex')); + + const byteArrayString = `[${secretBySender + .toString('hex') + .match(/.{1,2}/g)! + .map(byte => parseInt(byte, 16))}]`; + // Run with AZTEC_GENERATE_TEST_DATA=1 to update noir test data + updateInlineTestData( + 'noir-projects/aztec-nr/aztec/src/keys/point_to_symmetric_key.nr', + 'expected_key', + byteArrayString, + ); }); it('convert to and from encrypted buffer', () => { const data = randomBytes(253); - const ownerPrivKey = GrumpkinScalar.random(); - const ownerPubKey = grumpkin.mul(Grumpkin.generator, ownerPrivKey); - const ephPrivKey = GrumpkinScalar.random(); - const encrypted = encryptBuffer(data, ownerPubKey, ephPrivKey, grumpkin); - const decrypted = decryptBuffer(encrypted, ownerPrivKey, grumpkin); + const ownerSecretKey = GrumpkinScalar.random(); + const ownerPubKey = grumpkin.mul(Grumpkin.generator, ownerSecretKey); + const ephSecretKey = GrumpkinScalar.random(); + const encrypted = encryptBuffer(data, ephSecretKey, ownerPubKey); + const decrypted = decryptBuffer(encrypted, ownerSecretKey); expect(decrypted).not.toBeUndefined(); expect(decrypted).toEqual(data); }); it('decrypting gibberish returns undefined', () => { const data = randomBytes(253); - const ownerPrivKey = GrumpkinScalar.random(); - const ephPrivKey = GrumpkinScalar.random(); - const ownerPubKey = grumpkin.mul(Grumpkin.generator, ownerPrivKey); - const encrypted = encryptBuffer(data, ownerPubKey, ephPrivKey, grumpkin); + const ownerSecretKey = GrumpkinScalar.random(); + const ephSecretKey = GrumpkinScalar.random(); + const ownerPubKey = grumpkin.mul(Grumpkin.generator, ownerSecretKey); + const encrypted = encryptBuffer(data, ephSecretKey, ownerPubKey); // Introduce gibberish. const gibberish = Buffer.concat([randomBytes(8), encrypted.subarray(8)]); - const decrypted = decryptBuffer(gibberish, ownerPrivKey, grumpkin); + const decrypted = decryptBuffer(gibberish, ownerSecretKey); expect(decrypted).toBeUndefined(); }); }); diff --git a/yarn-project/circuit-types/src/logs/l1_note_payload/encrypt_buffer.ts b/yarn-project/circuit-types/src/logs/l1_note_payload/encrypt_buffer.ts index 314ef0d6c28f..28262ff3ba62 100644 --- a/yarn-project/circuit-types/src/logs/l1_note_payload/encrypt_buffer.ts +++ b/yarn-project/circuit-types/src/logs/l1_note_payload/encrypt_buffer.ts @@ -1,5 +1,5 @@ -import { type GrumpkinPrivateKey, type PublicKey } from '@aztec/circuits.js'; -import { type Grumpkin } from '@aztec/circuits.js/barretenberg'; +import { GeneratorIndex, type GrumpkinPrivateKey, type PublicKey } from '@aztec/circuits.js'; +import { Grumpkin } from '@aztec/circuits.js/barretenberg'; import { sha256 } from '@aztec/foundation/crypto'; import { Point } from '@aztec/foundation/fields'; import { numToUInt8 } from '@aztec/foundation/serialize'; @@ -12,14 +12,16 @@ import { createCipheriv, createDecipheriv } from 'browserify-cipher'; * the shared secret. The shared secret is then hashed using SHA-256 to produce the final * AES secret key. * - * @param ecdhPubKey - The ECDH public key represented as a PublicKey object. - * @param ecdhPrivKey - The ECDH private key represented as a Buffer object. - * @param grumpkin - The curve to use for curve operations. - * @returns A Buffer containing the derived AES secret key. + * @param secretKey - The secret key used to derive shared secret. + * @param publicKey - The public key used to derive shared secret. + * @returns A derived AES secret key. + * TODO(#5726): This function is called point_to_symmetric_key in Noir. I don't like that name much since point is not + * the only input of the function. Unify naming once we have a better name. */ -export function deriveAESSecret(ecdhPubKey: PublicKey, ecdhPrivKey: GrumpkinPrivateKey, curve: Grumpkin): Buffer { - const sharedSecret = curve.mul(ecdhPubKey, ecdhPrivKey); - const secretBuffer = Buffer.concat([sharedSecret.toBuffer(), numToUInt8(1)]); +export function deriveAESSecret(secretKey: GrumpkinPrivateKey, publicKey: PublicKey): Buffer { + const curve = new Grumpkin(); + const sharedSecret = curve.mul(publicKey, secretKey); + const secretBuffer = Buffer.concat([sharedSecret.toBuffer(), numToUInt8(GeneratorIndex.SYMMETRIC_KEY)]); const hash = sha256(secretBuffer); return hash; } @@ -31,40 +33,37 @@ export function deriveAESSecret(ecdhPubKey: PublicKey, ecdhPrivKey: GrumpkinPriv * with the provided curve instance for elliptic curve operations. * * @param data - The data buffer to be encrypted. - * @param ownerPubKey - The owner's public key as a PublicKey instance. - * @param ephPrivKey - The ephemeral private key as a Buffer instance. - * @param curve - The curve instance used for elliptic curve operations. + * @param ephSecretKey - The ephemeral secret key.. + * @param incomingViewingPublicKey - The note owner's incoming viewing public key. * @returns A Buffer containing the encrypted data and the ephemeral public key. */ export function encryptBuffer( data: Buffer, - ownerPubKey: PublicKey, - ephPrivKey: GrumpkinPrivateKey, - curve: Grumpkin, + ephSecretKey: GrumpkinPrivateKey, + incomingViewingPublicKey: PublicKey, ): Buffer { - const aesSecret = deriveAESSecret(ownerPubKey, ephPrivKey, curve); + const aesSecret = deriveAESSecret(ephSecretKey, incomingViewingPublicKey); const aesKey = aesSecret.subarray(0, 16); const iv = aesSecret.subarray(16, 32); const cipher = createCipheriv('aes-128-cbc', aesKey, iv); const plaintext = Buffer.concat([iv.subarray(0, 8), data]); - const ephPubKey = curve.mul(curve.generator(), ephPrivKey); + const curve = new Grumpkin(); + const ephPubKey = curve.mul(curve.generator(), ephSecretKey); + return Buffer.concat([cipher.update(plaintext), cipher.final(), ephPubKey.toBuffer()]); } /** - * Decrypts the given encrypted data buffer using the owner's private key and a Grumpkin curve. - * Extracts the ephemeral public key from the input data, derives the AES secret using - * the owner's private key, and decrypts the plaintext. - * If the decryption is successful, returns the decrypted plaintext, otherwise returns undefined. - * + * Decrypts the given encrypted data buffer using the provided secret key. * @param data - The encrypted data buffer to be decrypted. - * @param ownerPrivKey - The private key of the owner used for decryption. - * @param curve - The curve object used in the decryption process. + * @param incomingViewingSecretKey - The secret key used for decryption. * @returns The decrypted plaintext as a Buffer or undefined if decryption fails. */ -export function decryptBuffer(data: Buffer, ownerPrivKey: GrumpkinPrivateKey, curve: Grumpkin): Buffer | undefined { +export function decryptBuffer(data: Buffer, incomingViewingSecretKey: GrumpkinPrivateKey): Buffer | undefined { + // Extract the ephemeral public key from the end of the data const ephPubKey = Point.fromBuffer(data.subarray(-64)); - const aesSecret = deriveAESSecret(ephPubKey, ownerPrivKey, curve); + // Derive the AES secret key using the secret key and the ephemeral public key + const aesSecret = deriveAESSecret(incomingViewingSecretKey, ephPubKey); const aesKey = aesSecret.subarray(0, 16); const iv = aesSecret.subarray(16, 32); const cipher = createDecipheriv('aes-128-cbc', aesKey, iv); diff --git a/yarn-project/circuit-types/src/logs/l1_note_payload/l1_note_payload.test.ts b/yarn-project/circuit-types/src/logs/l1_note_payload/l1_note_payload.test.ts index d5a909fe7793..288e46db36d2 100644 --- a/yarn-project/circuit-types/src/logs/l1_note_payload/l1_note_payload.test.ts +++ b/yarn-project/circuit-types/src/logs/l1_note_payload/l1_note_payload.test.ts @@ -20,8 +20,8 @@ describe('L1 Note Payload', () => { const payload = L1NotePayload.random(); const ownerPrivKey = GrumpkinScalar.random(); const ownerPubKey = grumpkin.mul(Grumpkin.generator, ownerPrivKey); - const encrypted = payload.toEncryptedBuffer(ownerPubKey, grumpkin); - const decrypted = L1NotePayload.fromEncryptedBuffer(encrypted, ownerPrivKey, grumpkin); + const encrypted = payload.toEncryptedBuffer(ownerPubKey); + const decrypted = L1NotePayload.fromEncryptedBuffer(encrypted, ownerPrivKey); expect(decrypted).not.toBeUndefined(); expect(decrypted).toEqual(payload); }); @@ -29,9 +29,9 @@ describe('L1 Note Payload', () => { it('return undefined if unable to decrypt the encrypted buffer', () => { const payload = L1NotePayload.random(); const ownerPubKey = Point.random(); - const encrypted = payload.toEncryptedBuffer(ownerPubKey, grumpkin); + const encrypted = payload.toEncryptedBuffer(ownerPubKey); const randomPrivKey = GrumpkinScalar.random(); - const decrypted = L1NotePayload.fromEncryptedBuffer(encrypted, randomPrivKey, grumpkin); + const decrypted = L1NotePayload.fromEncryptedBuffer(encrypted, randomPrivKey); expect(decrypted).toBeUndefined(); }); }); diff --git a/yarn-project/circuit-types/src/logs/l1_note_payload/l1_note_payload.ts b/yarn-project/circuit-types/src/logs/l1_note_payload/l1_note_payload.ts index 8b34dd37b8f3..463512782a70 100644 --- a/yarn-project/circuit-types/src/logs/l1_note_payload/l1_note_payload.ts +++ b/yarn-project/circuit-types/src/logs/l1_note_payload/l1_note_payload.ts @@ -1,5 +1,4 @@ import { AztecAddress, type GrumpkinPrivateKey, type PublicKey } from '@aztec/circuits.js'; -import { type Grumpkin } from '@aztec/circuits.js/barretenberg'; import { Fr, GrumpkinScalar } from '@aztec/foundation/fields'; import { BufferReader, serializeToBuffer } from '@aztec/foundation/serialize'; @@ -56,28 +55,22 @@ export class L1NotePayload { /** * Encrypt the L1NotePayload object using the owner's public key and the ephemeral private key. - * @param ownerPubKey - Public key of the owner of the L1NotePayload object. - * @param curve - The curve instance to use. + * @param incomingViewingPubKey - Public key of the owner of the L1NotePayload object. * @returns The encrypted L1NotePayload object. */ - public toEncryptedBuffer(ownerPubKey: PublicKey, curve: Grumpkin): Buffer { - const ephPrivKey: GrumpkinPrivateKey = GrumpkinScalar.random(); - return encryptBuffer(this.toBuffer(), ownerPubKey, ephPrivKey, curve); + public toEncryptedBuffer(incomingViewingPubKey: PublicKey): Buffer { + const ephSecretKey: GrumpkinPrivateKey = GrumpkinScalar.random(); + return encryptBuffer(this.toBuffer(), ephSecretKey, incomingViewingPubKey); } /** - * Decrypts the L1NotePayload object using the owner's private key. + * Decrypts the L1NotePayload object using the owner's incoming viewing secret key. * @param data - Encrypted L1NotePayload object. - * @param ownerPrivKey - Private key of the owner of the L1NotePayload object. - * @param curve - The curve instance to use. + * @param incomingViewingSecretKey - Incoming viewing secret key of the owner of the L1NotePayload object. * @returns Instance of L1NotePayload if the decryption was successful, undefined otherwise. */ - static fromEncryptedBuffer( - data: Buffer, - ownerPrivKey: GrumpkinPrivateKey, - curve: Grumpkin, - ): L1NotePayload | undefined { - const buf = decryptBuffer(data, ownerPrivKey, curve); + static fromEncryptedBuffer(data: Buffer, incomingViewingSecretKey: GrumpkinPrivateKey): L1NotePayload | undefined { + const buf = decryptBuffer(data, incomingViewingSecretKey); if (!buf) { return; } diff --git a/yarn-project/circuit-types/src/logs/l1_note_payload/tagged_note.test.ts b/yarn-project/circuit-types/src/logs/l1_note_payload/tagged_note.test.ts index 2eb74ba247a4..bbd171f37020 100644 --- a/yarn-project/circuit-types/src/logs/l1_note_payload/tagged_note.test.ts +++ b/yarn-project/circuit-types/src/logs/l1_note_payload/tagged_note.test.ts @@ -23,8 +23,8 @@ describe('L1 Note Payload', () => { const taggedNote = new TaggedNote(payload); const ownerPrivKey = GrumpkinScalar.random(); const ownerPubKey = grumpkin.mul(Grumpkin.generator, ownerPrivKey); - const encrypted = taggedNote.toEncryptedBuffer(ownerPubKey, grumpkin); - const decrypted = TaggedNote.fromEncryptedBuffer(encrypted, ownerPrivKey, grumpkin); + const encrypted = taggedNote.toEncryptedBuffer(ownerPubKey); + const decrypted = TaggedNote.fromEncryptedBuffer(encrypted, ownerPrivKey); expect(decrypted).not.toBeUndefined(); expect(decrypted?.notePayload).toEqual(payload); }); @@ -33,9 +33,9 @@ describe('L1 Note Payload', () => { const payload = L1NotePayload.random(); const taggedNote = new TaggedNote(payload); const ownerPubKey = Point.random(); - const encrypted = taggedNote.toEncryptedBuffer(ownerPubKey, grumpkin); + const encrypted = taggedNote.toEncryptedBuffer(ownerPubKey); const randomPrivKey = GrumpkinScalar.random(); - const decrypted = TaggedNote.fromEncryptedBuffer(encrypted, randomPrivKey, grumpkin); + const decrypted = TaggedNote.fromEncryptedBuffer(encrypted, randomPrivKey); expect(decrypted).toBeUndefined(); }); }); diff --git a/yarn-project/circuit-types/src/logs/l1_note_payload/tagged_note.ts b/yarn-project/circuit-types/src/logs/l1_note_payload/tagged_note.ts index ddc362e46281..4e698e382eb2 100644 --- a/yarn-project/circuit-types/src/logs/l1_note_payload/tagged_note.ts +++ b/yarn-project/circuit-types/src/logs/l1_note_payload/tagged_note.ts @@ -1,5 +1,4 @@ import { type GrumpkinPrivateKey, type PublicKey } from '@aztec/circuits.js'; -import { type Grumpkin } from '@aztec/circuits.js/barretenberg'; import { Fr } from '@aztec/foundation/fields'; import { BufferReader, serializeToBuffer } from '@aztec/foundation/serialize'; @@ -37,11 +36,10 @@ export class TaggedNote { /** * Encrypt the L1NotePayload object using the owner's public key and the ephemeral private key, then attach the tag. * @param ownerPubKey - Public key of the owner of the TaggedNote object. - * @param curve - The curve instance to use. * @returns The encrypted TaggedNote object. */ - public toEncryptedBuffer(ownerPubKey: PublicKey, curve: Grumpkin): Buffer { - const encryptedL1NotePayload = this.notePayload.toEncryptedBuffer(ownerPubKey, curve); + public toEncryptedBuffer(ownerPubKey: PublicKey): Buffer { + const encryptedL1NotePayload = this.notePayload.toEncryptedBuffer(ownerPubKey); return serializeToBuffer(this.tag, encryptedL1NotePayload); } @@ -49,16 +47,15 @@ export class TaggedNote { * Decrypts the L1NotePayload object using the owner's private key. * @param data - Encrypted TaggedNote object. * @param ownerPrivKey - Private key of the owner of the TaggedNote object. - * @param curve - The curve instance to use. * @returns Instance of TaggedNote if the decryption was successful, undefined otherwise. */ - static fromEncryptedBuffer(data: Buffer, ownerPrivKey: GrumpkinPrivateKey, curve: Grumpkin): TaggedNote | undefined { + static fromEncryptedBuffer(data: Buffer, ownerPrivKey: GrumpkinPrivateKey): TaggedNote | undefined { const reader = BufferReader.asReader(data); const tag = Fr.fromBuffer(reader); const encryptedL1NotePayload = reader.readToEnd(); - const payload = L1NotePayload.fromEncryptedBuffer(encryptedL1NotePayload, ownerPrivKey, curve); + const payload = L1NotePayload.fromEncryptedBuffer(encryptedL1NotePayload, ownerPrivKey); if (!payload) { return; } diff --git a/yarn-project/circuits.js/src/constants.gen.ts b/yarn-project/circuits.js/src/constants.gen.ts index 1b98ef98a39c..0e2f73c46d25 100644 --- a/yarn-project/circuits.js/src/constants.gen.ts +++ b/yarn-project/circuits.js/src/constants.gen.ts @@ -101,8 +101,10 @@ export const NULLIFIER_KEY_VALIDATION_REQUEST_LENGTH = 3; export const NULLIFIER_KEY_VALIDATION_REQUEST_CONTEXT_LENGTH = 4; export const PARTIAL_STATE_REFERENCE_LENGTH = 6; export const READ_REQUEST_LENGTH = 2; +export const NOTE_HASH_LENGTH = 2; +export const NOTE_HASH_CONTEXT_LENGTH = 3; +export const NULLIFIER_LENGTH = 3; export const SIDE_EFFECT_LENGTH = 2; -export const SIDE_EFFECT_LINKED_TO_NOTE_HASH_LENGTH = 3; export const STATE_REFERENCE_LENGTH = APPEND_ONLY_TREE_SNAPSHOT_LENGTH + PARTIAL_STATE_REFERENCE_LENGTH; export const TX_CONTEXT_LENGTH = 2 + GAS_SETTINGS_LENGTH; export const TX_REQUEST_LENGTH = 2 + TX_CONTEXT_LENGTH + FUNCTION_DATA_LENGTH; @@ -115,8 +117,8 @@ export const PRIVATE_CIRCUIT_PUBLIC_INPUTS_LENGTH = SIDE_EFFECT_LENGTH * MAX_NOTE_HASH_READ_REQUESTS_PER_CALL + READ_REQUEST_LENGTH * MAX_NULLIFIER_READ_REQUESTS_PER_CALL + NULLIFIER_KEY_VALIDATION_REQUEST_LENGTH * MAX_NULLIFIER_KEY_VALIDATION_REQUESTS_PER_CALL + - SIDE_EFFECT_LENGTH * MAX_NEW_NOTE_HASHES_PER_CALL + - SIDE_EFFECT_LINKED_TO_NOTE_HASH_LENGTH * MAX_NEW_NULLIFIERS_PER_CALL + + NOTE_HASH_LENGTH * MAX_NEW_NOTE_HASHES_PER_CALL + + NULLIFIER_LENGTH * MAX_NEW_NULLIFIERS_PER_CALL + MAX_PRIVATE_CALL_STACK_LENGTH_PER_CALL + MAX_PUBLIC_CALL_STACK_LENGTH_PER_CALL + L2_TO_L1_MESSAGE_LENGTH * MAX_NEW_L2_TO_L1_MSGS_PER_CALL + @@ -134,8 +136,8 @@ export const PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH = CONTRACT_STORAGE_UPDATE_REQUEST_LENGTH * MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_CALL + CONTRACT_STORAGE_READ_LENGTH * MAX_PUBLIC_DATA_READS_PER_CALL + MAX_PUBLIC_CALL_STACK_LENGTH_PER_CALL + - SIDE_EFFECT_LENGTH * MAX_NEW_NOTE_HASHES_PER_CALL + - SIDE_EFFECT_LINKED_TO_NOTE_HASH_LENGTH * MAX_NEW_NULLIFIERS_PER_CALL + + NOTE_HASH_LENGTH * MAX_NEW_NOTE_HASHES_PER_CALL + + NULLIFIER_LENGTH * MAX_NEW_NULLIFIERS_PER_CALL + L2_TO_L1_MESSAGE_LENGTH * MAX_NEW_L2_TO_L1_MSGS_PER_CALL + 2 + SIDE_EFFECT_LENGTH * MAX_UNENCRYPTED_LOGS_PER_CALL + @@ -210,4 +212,5 @@ export enum GeneratorIndex { NOTE_NULLIFIER = 52, INNER_NOTE_HASH = 53, NOTE_CONTENT_HASH = 54, + SYMMETRIC_KEY = 55, } diff --git a/yarn-project/foundation/src/testing/test_data.ts b/yarn-project/foundation/src/testing/test_data.ts index a92794e64f6a..dba6932cd5cc 100644 --- a/yarn-project/foundation/src/testing/test_data.ts +++ b/yarn-project/foundation/src/testing/test_data.ts @@ -66,7 +66,7 @@ export function updateInlineTestData(targetFileFromRepoRoot: string, itemName: s const logger = createConsoleLogger('aztec:testing:test_data'); const targetFile = getPathToFile(targetFileFromRepoRoot); const contents = readFileSync(targetFile, 'utf8').toString(); - const regex = new RegExp(`let ${itemName} = .*;`, 'g'); + const regex = new RegExp(`let ${itemName} = [\\s\\S]*?;`, 'g'); if (!regex.exec(contents)) { throw new Error(`Test data marker for ${itemName} not found in ${targetFile}`); } diff --git a/yarn-project/pxe/src/note_processor/note_processor.test.ts b/yarn-project/pxe/src/note_processor/note_processor.test.ts index b182e5caeb55..1ddd5bd1ce62 100644 --- a/yarn-project/pxe/src/note_processor/note_processor.test.ts +++ b/yarn-project/pxe/src/note_processor/note_processor.test.ts @@ -17,7 +17,6 @@ import { type PublicKey, deriveKeys, } from '@aztec/circuits.js'; -import { Grumpkin } from '@aztec/circuits.js/barretenberg'; import { pedersenHash } from '@aztec/foundation/crypto'; import { Point } from '@aztec/foundation/fields'; import { openTmpStore } from '@aztec/kv-store/utils'; @@ -34,7 +33,6 @@ import { NoteProcessor } from './note_processor.js'; const TXS_PER_BLOCK = 4; describe('Note Processor', () => { - const grumpkin = new Grumpkin(); let database: PxeDatabase; let aztecNode: ReturnType>; let addNotesSpy: any; @@ -72,7 +70,7 @@ describe('Note Processor', () => { ownedL1NotePayloads.push(note.notePayload); } // const encryptedNote = - const log = note.toEncryptedBuffer(publicKey, grumpkin); + const log = note.toEncryptedBuffer(publicKey); // 1 tx containing 1 function invocation containing 1 log logs.push(new EncryptedFunctionL2Logs([new EncryptedL2Log(log)])); } diff --git a/yarn-project/pxe/src/note_processor/note_processor.ts b/yarn-project/pxe/src/note_processor/note_processor.ts index 0950a427598f..3eaa6b3006d5 100644 --- a/yarn-project/pxe/src/note_processor/note_processor.ts +++ b/yarn-project/pxe/src/note_processor/note_processor.ts @@ -8,7 +8,6 @@ import { } from '@aztec/circuit-types'; import { type NoteProcessorStats } from '@aztec/circuit-types/stats'; import { INITIAL_L2_BLOCK_NUM, MAX_NEW_NOTE_HASHES_PER_TX, type PublicKey } from '@aztec/circuits.js'; -import { Grumpkin } from '@aztec/circuits.js/barretenberg'; import { type Fr } from '@aztec/foundation/fields'; import { createDebugLogger } from '@aztec/foundation/log'; import { Timer } from '@aztec/foundation/timer'; @@ -99,7 +98,6 @@ export class NoteProcessor { return; } - const curve = new Grumpkin(); const blocksAndNotes: ProcessedData[] = []; // Keep track of notes that we couldn't process because the contract was not found. const deferredNoteDaos: DeferredNoteDao[] = []; @@ -132,7 +130,7 @@ export class NoteProcessor { for (const functionLogs of txFunctionLogs) { for (const log of functionLogs.logs) { this.stats.seen++; - const taggedNote = TaggedNote.fromEncryptedBuffer(log.data, secretKey, curve); + const taggedNote = TaggedNote.fromEncryptedBuffer(log.data, secretKey); if (taggedNote?.notePayload) { const { notePayload: payload } = taggedNote; // We have successfully decrypted the data. diff --git a/yarn-project/simulator/src/client/client_execution_context.ts b/yarn-project/simulator/src/client/client_execution_context.ts index 57a79776a973..22beba44aa13 100644 --- a/yarn-project/simulator/src/client/client_execution_context.ts +++ b/yarn-project/simulator/src/client/client_execution_context.ts @@ -21,7 +21,7 @@ import { type SideEffect, type TxContext, } from '@aztec/circuits.js'; -import { Aes128, type Grumpkin } from '@aztec/circuits.js/barretenberg'; +import { Aes128 } from '@aztec/circuits.js/barretenberg'; import { computePublicDataTreeLeafSlot, computeUniqueNoteHash, siloNoteHash } from '@aztec/circuits.js/hash'; import { type FunctionAbi, type FunctionArtifact, countArgumentsSize } from '@aztec/foundation/abi'; import { type AztecAddress } from '@aztec/foundation/aztec-address'; @@ -77,7 +77,6 @@ export class ClientExecutionContext extends ViewDataOracle { private readonly packedValuesCache: PackedValuesCache, private readonly noteCache: ExecutionNoteCache, db: DBOracle, - private readonly curve: Grumpkin, private node: AztecNode, protected sideEffectCounter: number = 0, log = createDebugLogger('aztec:simulator:client_execution_context'), @@ -345,7 +344,7 @@ export class ClientExecutionContext extends ViewDataOracle { const note = new Note(log); const l1NotePayload = new L1NotePayload(note, contractAddress, storageSlot, noteTypeId); const taggedNote = new TaggedNote(l1NotePayload); - const encryptedNote = taggedNote.toEncryptedBuffer(publicKey, this.curve); + const encryptedNote = taggedNote.toEncryptedBuffer(publicKey); const encryptedLog = new EncryptedL2Log(encryptedNote); this.encryptedLogs.push(encryptedLog); return Fr.fromBuffer(encryptedLog.hash()); @@ -421,7 +420,6 @@ export class ClientExecutionContext extends ViewDataOracle { this.packedValuesCache, this.noteCache, this.db, - this.curve, this.node, sideEffectCounter, ); diff --git a/yarn-project/simulator/src/client/simulator.ts b/yarn-project/simulator/src/client/simulator.ts index a0499982a292..cb767ea4d243 100644 --- a/yarn-project/simulator/src/client/simulator.ts +++ b/yarn-project/simulator/src/client/simulator.ts @@ -1,6 +1,5 @@ import { type AztecNode, type FunctionCall, type Note, type TxExecutionRequest } from '@aztec/circuit-types'; import { CallContext, FunctionData } from '@aztec/circuits.js'; -import { Grumpkin } from '@aztec/circuits.js/barretenberg'; import { type ArrayType, type FunctionArtifactWithDebugMetadata, @@ -79,8 +78,6 @@ export class AcirSimulator { ); } - const curve = new Grumpkin(); - const header = await this.db.getHeader(); // reserve the first side effect for the tx hash (inserted by the private kernel) @@ -104,7 +101,6 @@ export class AcirSimulator { PackedValuesCache.create(request.argsOfCalls), new ExecutionNoteCache(), this.db, - curve, this.node, startSideEffectCounter, ); From 9f5773353aa0261fa07a81704bcadcee513d42c5 Mon Sep 17 00:00:00 2001 From: Innokentii Sennovskii Date: Tue, 30 Apr 2024 15:49:09 +0100 Subject: [PATCH 10/42] feat: Avoiding redundant computation in PG (#5844) This PR reduces PG computation time by removing computation on indices over which accumulated relation values are expected to be zero. This gives us a speedup of 4-5%. The PR also parallelised pertubator root construction. Before: x86_64: ![image](https://github.com/AztecProtocol/aztec-packages/assets/4798775/80247864-1c1c-4e34-8756-a8bd44bdbab2) wasm: ![image](https://github.com/AztecProtocol/aztec-packages/assets/4798775/649dfd97-d65c-48a5-8b8c-02fb3fbb9f47) After: x86_64: ![image](https://github.com/AztecProtocol/aztec-packages/assets/4798775/d453e026-7a88-4646-8094-b7102916a2af) wasm: ![image](https://github.com/AztecProtocol/aztec-packages/assets/4798775/549b89a3-d1dd-4058-84a2-92856180d15d) --- .../cpp/src/barretenberg/flavor/flavor.hpp | 15 +- .../barretenberg/polynomials/univariate.hpp | 214 ++++++++++------ .../protogalaxy/combiner.test.cpp | 44 ++-- .../protogalaxy/protogalaxy.test.cpp | 5 + .../protogalaxy/protogalaxy_prover.hpp | 242 +++++++++++++++--- .../relations/nested_containers.hpp | 24 +- .../barretenberg/relations/relation_types.hpp | 6 + .../goblin_ultra_flavor.hpp | 14 + .../stdlib_circuit_builders/ultra_flavor.hpp | 11 + .../sumcheck/instance/instances.hpp | 7 +- .../barretenberg/vm/generated/avm_flavor.hpp | 7 + 11 files changed, 445 insertions(+), 144 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/flavor/flavor.hpp b/barretenberg/cpp/src/barretenberg/flavor/flavor.hpp index 18eda8ecd0cf..fb30168d58dd 100644 --- a/barretenberg/cpp/src/barretenberg/flavor/flavor.hpp +++ b/barretenberg/cpp/src/barretenberg/flavor/flavor.hpp @@ -251,18 +251,23 @@ template static constexpr size_t compute * @details The size of the outer tuple is equal to the number of relations. Each relation contributes an inner tuple of * univariates whose size is equal to the number of subrelations of the relation. The length of a univariate in an inner * tuple is determined by the corresponding subrelation length and the number of instances to be folded. + * @tparam optimised Enable optimised version with skipping some of the computation */ -template +template static constexpr auto create_protogalaxy_tuple_of_tuples_of_univariates() { if constexpr (Index >= std::tuple_size::value) { return std::tuple<>{}; // Return empty when reach end of the tuple } else { using UnivariateTuple = - typename std::tuple_element_t::template ProtogalaxyTupleOfUnivariatesOverSubrelations; - return std::tuple_cat(std::tuple{}, - create_protogalaxy_tuple_of_tuples_of_univariates()); + std::conditional_t:: + template OptimisedProtogalaxyTupleOfUnivariatesOverSubrelations, + typename std::tuple_element_t:: + template ProtogalaxyTupleOfUnivariatesOverSubrelations>; + return std::tuple_cat( + std::tuple{}, + create_protogalaxy_tuple_of_tuples_of_univariates()); } } diff --git a/barretenberg/cpp/src/barretenberg/polynomials/univariate.hpp b/barretenberg/cpp/src/barretenberg/polynomials/univariate.hpp index aedc13537875..6471ba85b56e 100644 --- a/barretenberg/cpp/src/barretenberg/polynomials/univariate.hpp +++ b/barretenberg/cpp/src/barretenberg/polynomials/univariate.hpp @@ -13,17 +13,22 @@ namespace bb { * of the data in those univariates. We do that by taking a view of those elements and then, as needed, using this to * populate new containers. */ -template class UnivariateView; +template class UnivariateView; /** * @brief A univariate polynomial represented by its values on {domain_start, domain_start + 1,..., domain_end - 1}. For * memory efficiency purposes, we store the evaluations in an array starting from 0 and make the mapping to the right * domain under the hood. + * + * @tparam skip_count Skip computing the values of elements [domain_start+1,..,domain_start+skip_count]. Used for + * optimising computation in protogalaxy. The value at [domain_start] is the value from the accumulator instance, while + * the values in [domain_start+1, ... domain_start + skip_count] in the accumulator should be zero if the original + * instances are correct. */ -template class Univariate { +template class Univariate { public: static constexpr size_t LENGTH = domain_end - domain_start; - using View = UnivariateView; + using View = UnivariateView; using value_type = Fr; // used to get the type of the elements consistently with std::array @@ -40,8 +45,27 @@ template class Univariate Univariate(Univariate&& other) noexcept = default; Univariate& operator=(const Univariate& other) = default; Univariate& operator=(Univariate&& other) noexcept = default; - // Construct constant Univariate from scalar which represents the value that all the points in the domain evaluate - // to + + /** + * @brief Convert from a version with skipped evaluations to one without skipping (with zeroes in previously skipped + * locations) + * + * @return Univariate + */ + Univariate convert() const noexcept + { + Univariate result; + result.evaluations[0] = evaluations[0]; + for (size_t i = 1; i < skip_count + 1; i++) { + result.evaluations[i] = Fr::zero(); + } + for (size_t i = skip_count + 1; i < LENGTH; i++) { + result.evaluations[i] = evaluations[i]; + } + return result; + } + // Construct constant Univariate from scalar which represents the value that all the points in the domain + // evaluate to explicit Univariate(Fr value) : evaluations{} { @@ -50,7 +74,7 @@ template class Univariate } } // Construct Univariate from UnivariateView - explicit Univariate(UnivariateView in) + explicit Univariate(UnivariateView in) : evaluations{} { for (size_t i = 0; i < in.evaluations.size(); ++i) { @@ -77,7 +101,7 @@ template class Univariate static Univariate get_random() { - auto output = Univariate(); + auto output = Univariate(); for (size_t i = 0; i != LENGTH; ++i) { output.value_at(i) = Fr::random_element(); } @@ -86,7 +110,7 @@ template class Univariate static Univariate zero() { - auto output = Univariate(); + auto output = Univariate(); for (size_t i = 0; i != LENGTH; ++i) { output.value_at(i) = Fr::zero(); } @@ -100,21 +124,25 @@ template class Univariate Univariate& operator+=(const Univariate& other) { - for (size_t i = 0; i < LENGTH; ++i) { + evaluations[0] += other.evaluations[0]; + for (size_t i = skip_count + 1; i < LENGTH; ++i) { evaluations[i] += other.evaluations[i]; } return *this; } Univariate& operator-=(const Univariate& other) { - for (size_t i = 0; i < LENGTH; ++i) { + evaluations[0] -= other.evaluations[0]; + for (size_t i = skip_count + 1; i < LENGTH; ++i) { + evaluations[i] -= other.evaluations[i]; } return *this; } Univariate& operator*=(const Univariate& other) { - for (size_t i = 0; i < LENGTH; ++i) { + evaluations[0] *= other.evaluations[0]; + for (size_t i = skip_count + 1; i < LENGTH; ++i) { evaluations[i] *= other.evaluations[i]; } return *this; @@ -135,8 +163,12 @@ template class Univariate Univariate operator-() const { Univariate res(*this); + size_t i = 0; for (auto& eval : res.evaluations) { - eval = -eval; + if (i == 0 || i >= (skip_count + 1)) { + eval = -eval; + } + i++; } return res; } @@ -151,23 +183,39 @@ template class Univariate // Operations between Univariate and scalar Univariate& operator+=(const Fr& scalar) { + size_t i = 0; for (auto& eval : evaluations) { - eval += scalar; + if (i == 0 || i >= (skip_count + 1)) { + eval += scalar; + } + i++; } return *this; } Univariate& operator-=(const Fr& scalar) { + size_t i = 0; for (auto& eval : evaluations) { - eval -= scalar; + // If skip count is zero, will be enabled on every line, otherwise don't compute for [domain_start+1,.., + // domain_start + skip_count] + if (i == 0 || i >= (skip_count + 1)) { + eval -= scalar; + } + i++; } return *this; } Univariate& operator*=(const Fr& scalar) { + size_t i = 0; for (auto& eval : evaluations) { - eval *= scalar; + // If skip count is zero, will be enabled on every line, otherwise don't compute for [domain_start+1,.., + // domain_start + skip_count] + if (i == 0 || i >= (skip_count + 1)) { + eval *= scalar; + } + i++; } return *this; } @@ -194,45 +242,48 @@ template class Univariate } // Operations between Univariate and UnivariateView - Univariate& operator+=(const UnivariateView& view) + Univariate& operator+=(const UnivariateView& view) { - for (size_t i = 0; i < LENGTH; ++i) { + evaluations[0] += view.evaluations[0]; + for (size_t i = skip_count + 1; i < LENGTH; ++i) { evaluations[i] += view.evaluations[i]; } return *this; } - Univariate& operator-=(const UnivariateView& view) + Univariate& operator-=(const UnivariateView& view) { - for (size_t i = 0; i < LENGTH; ++i) { + evaluations[0] -= view.evaluations[0]; + for (size_t i = skip_count + 1; i < LENGTH; ++i) { evaluations[i] -= view.evaluations[i]; } return *this; } - Univariate& operator*=(const UnivariateView& view) + Univariate& operator*=(const UnivariateView& view) { - for (size_t i = 0; i < LENGTH; ++i) { + evaluations[0] *= view.evaluations[0]; + for (size_t i = skip_count + 1; i < LENGTH; ++i) { evaluations[i] *= view.evaluations[i]; } return *this; } - Univariate operator+(const UnivariateView& view) const + Univariate operator+(const UnivariateView& view) const { Univariate res(*this); res += view; return res; } - Univariate operator-(const UnivariateView& view) const + Univariate operator-(const UnivariateView& view) const { Univariate res(*this); res -= view; return res; } - Univariate operator*(const UnivariateView& view) const + Univariate operator*(const UnivariateView& view) const { Univariate res(*this); res *= view; @@ -256,39 +307,42 @@ template class Univariate } /** - * @brief Given a univariate f represented by {f(domain_start), ..., f(domain_end - 1)}, compute the evaluations - * {f(domain_end),..., f(extended_domain_end -1)} and return the Univariate represented by {f(domain_start),..., - * f(extended_domain_end -1)} + * @brief Given a univariate f represented by {f(domain_start), ..., f(domain_end - 1)}, compute the + * evaluations {f(domain_end),..., f(extended_domain_end -1)} and return the Univariate represented by + * {f(domain_start),..., f(extended_domain_end -1)} * - * @details Write v_i = f(x_i) on a the domain {x_{domain_start}, ..., x_{domain_end-1}}. To efficiently compute the - * needed values of f, we use the barycentric formula + * @details Write v_i = f(x_i) on a the domain {x_{domain_start}, ..., x_{domain_end-1}}. To efficiently + * compute the needed values of f, we use the barycentric formula * - f(x) = B(x) Σ_{i=domain_start}^{domain_end-1} v_i / (d_i*(x-x_i)) * where * - B(x) = Π_{i=domain_start}^{domain_end-1} (x-x_i) - * - d_i = Π_{j ∈ {domain_start, ..., domain_end-1}, j≠i} (x_i-x_j) for i ∈ {domain_start, ..., domain_end-1} + * - d_i = Π_{j ∈ {domain_start, ..., domain_end-1}, j≠i} (x_i-x_j) for i ∈ {domain_start, ..., + * domain_end-1} * - * When the domain size is two, extending f = v0(1-X) + v1X to a new value involves just one addition and a - * subtraction: setting Δ = v1-v0, the values of f(X) are f(0)=v0, f(1)= v0 + Δ, v2 = f(1) + Δ, v3 = f(2) + Δ... + * When the domain size is two, extending f = v0(1-X) + v1X to a new value involves just one addition + * and a subtraction: setting Δ = v1-v0, the values of f(X) are f(0)=v0, f(1)= v0 + Δ, v2 = f(1) + Δ, v3 + * = f(2) + Δ... * */ - template Univariate extend_to() const + template + Univariate extend_to() const { const size_t EXTENDED_LENGTH = EXTENDED_DOMAIN_END - domain_start; using Data = BarycentricData; static_assert(EXTENDED_LENGTH >= LENGTH); - Univariate result; + Univariate result; std::copy(evaluations.begin(), evaluations.end(), result.evaluations.begin()); static constexpr Fr inverse_two = Fr(2).invert(); + static_assert(NUM_SKIPPED_INDICES < LENGTH); if constexpr (LENGTH == 2) { Fr delta = value_at(1) - value_at(0); static_assert(EXTENDED_LENGTH != 0); for (size_t idx = domain_end - 1; idx < EXTENDED_DOMAIN_END - 1; idx++) { result.value_at(idx + 1) = result.value_at(idx) + delta; } - return result; } else if constexpr (LENGTH == 3) { // Based off https://hackmd.io/@aztec-network/SyR45cmOq?type=view // The technique used here is the same as the length == 3 case below. @@ -304,7 +358,6 @@ template class Univariate result.value_at(idx + 1) = result.value_at(idx) + extra; extra += a2; } - return result; } else if constexpr (LENGTH == 4) { static constexpr Fr inverse_six = Fr(6).invert(); // computed at compile time for efficiency @@ -315,8 +368,8 @@ template class Univariate // a*1 + b*1 + c*1 + d = f(1) // a*2^3 + b*2^2 + c*2 + d = f(2) // a*3^3 + b*3^2 + c*3 + d = f(3) - // These equations can be rewritten as a matrix equation M * [a, b, c, d] = [f(0), f(1), f(2), f(3)], where - // M is: + // These equations can be rewritten as a matrix equation M * [a, b, c, d] = [f(0), f(1), f(2), + // f(3)], where M is: // 0, 0, 0, 1 // 1, 1, 1, 1 // 2^3, 2^2, 2, 1 @@ -326,9 +379,9 @@ template class Univariate // 1, -5/2, 2, -1/2 // -11/6, 3, -3/2, 1/3 // 1, 0, 0, 0 - // To compute these values, we can multiply everything by 6 and multiply by inverse_six at the end for each - // coefficient The resulting computation here does 18 field adds, 6 subtracts, 3 muls to compute a, b, c, - // and d. + // To compute these values, we can multiply everything by 6 and multiply by inverse_six at the + // end for each coefficient The resulting computation here does 18 field adds, 6 subtracts, 3 + // muls to compute a, b, c, and d. Fr zero_times_3 = value_at(0) + value_at(0) + value_at(0); Fr zero_times_6 = zero_times_3 + zero_times_3; Fr zero_times_12 = zero_times_6 + zero_times_6; @@ -368,7 +421,6 @@ template class Univariate linear_term += three_a_plus_two_b; } - return result; } else { for (size_t k = domain_end; k != EXTENDED_DOMAIN_END; ++k) { result.value_at(k) = 0; @@ -381,8 +433,8 @@ template class Univariate // scale the sum by the the value of of B(x) result.value_at(k) *= Data::full_numerator_values[k]; } - return result; } + return result; } /** @@ -399,8 +451,8 @@ template class Univariate full_numerator_value *= u - i; } - // build set of domain size-many denominator inverses 1/(d_i*(x_k - x_j)). will multiply against each of - // these (rather than to divide by something) for each barycentric evaluation + // build set of domain size-many denominator inverses 1/(d_i*(x_k - x_j)). will multiply against + // each of these (rather than to divide by something) for each barycentric evaluation std::array denominator_inverses; for (size_t i = 0; i != LENGTH; ++i) { Fr inv = Data::lagrange_denominators[i]; @@ -443,7 +495,7 @@ inline void write(B& it, Univariate const& univari write(it, univariate.evaluations); } -template class UnivariateView { +template class UnivariateView { public: static constexpr size_t LENGTH = domain_end - domain_start; std::span evaluations; @@ -453,77 +505,84 @@ template class Univariate const Fr& value_at(size_t i) const { return evaluations[i]; }; template - explicit UnivariateView(const Univariate& univariate_in) + explicit UnivariateView(const Univariate& univariate_in) : evaluations(std::span(univariate_in.evaluations.data(), LENGTH)){}; - Univariate operator+(const UnivariateView& other) const + Univariate operator+(const UnivariateView& other) const { - Univariate res(*this); + Univariate res(*this); res += other; return res; } - Univariate operator-(const UnivariateView& other) const + Univariate operator-(const UnivariateView& other) const { - Univariate res(*this); + Univariate res(*this); res -= other; return res; } - Univariate operator-() const + Univariate operator-() const { - Univariate res(*this); + Univariate res(*this); + size_t i = 0; for (auto& eval : res.evaluations) { - eval = -eval; + if (i == 0 || i >= (skip_count + 1)) { + eval = -eval; + } + i++; } return res; } - Univariate operator*(const UnivariateView& other) const + Univariate operator*(const UnivariateView& other) const { - Univariate res(*this); + Univariate res(*this); res *= other; return res; } - Univariate operator*(const Univariate& other) const + Univariate operator*( + const Univariate& other) const { - Univariate res(*this); + Univariate res(*this); res *= other; return res; } - Univariate operator+(const Univariate& other) const + Univariate operator+( + const Univariate& other) const { - Univariate res(*this); + Univariate res(*this); res += other; return res; } - Univariate operator+(const Fr& other) const + Univariate operator+(const Fr& other) const { - Univariate res(*this); + Univariate res(*this); res += other; return res; } - Univariate operator-(const Fr& other) const + Univariate operator-(const Fr& other) const { - Univariate res(*this); + Univariate res(*this); res -= other; return res; } - Univariate operator*(const Fr& other) const + Univariate operator*(const Fr& other) const { - Univariate res(*this); + Univariate res(*this); res *= other; return res; } - Univariate operator-(const Univariate& other) const + Univariate operator-( + const Univariate& other) const { - Univariate res(*this); + Univariate res(*this); res -= other; return res; } @@ -546,8 +605,8 @@ template class Univariate }; /** - * @brief Create a sub-array of `elements` at the indices given in the template pack `Is`, converting them to the new - * type T. + * @brief Create a sub-array of `elements` at the indices given in the template pack `Is`, converting them + * to the new type T. * * @tparam T type to convert to * @tparam U type to convert from @@ -555,8 +614,8 @@ template class Univariate * @tparam Is list of indices we want in the returned array. When the second argument is called with * `std::make_index_sequence`, these will be `0, 1, ..., N-1`. * @param elements array to convert from - * @return std::array result array s.t. result[i] = T(elements[Is[i]]). By default, Is[i] = i when - * called with `std::make_index_sequence`. + * @return std::array result array s.t. result[i] = T(elements[Is[i]]). By default, Is[i] + * = i when called with `std::make_index_sequence`. */ template std::array array_to_array_aux(const std::array& elements, std::index_sequence) @@ -568,11 +627,12 @@ std::array array_to_array_aux(const std::array& elements * @brief Given an std::array, returns an std::array, by calling the (explicit) constructor T(U). * * @details https://stackoverflow.com/a/32175958 - * The main use case is to convert an array of `Univariate` into `UnivariateView`. The main use case would be to let - * Sumcheck decide the required degree of the relation evaluation, rather than hardcoding it inside the relation. The - * `_aux` version could also be used to create an array of only the polynomials required by the relation, and it could - * help us implement the optimization where we extend each edge only up to the maximum degree that is required over all - * relations (for example, `L_LAST` only needs degree 3). + * The main use case is to convert an array of `Univariate` into `UnivariateView`. The main use case would + * be to let Sumcheck decide the required degree of the relation evaluation, rather than hardcoding it + * inside the relation. The + * `_aux` version could also be used to create an array of only the polynomials required by the relation, + * and it could help us implement the optimization where we extend each edge only up to the maximum degree + * that is required over all relations (for example, `L_LAST` only needs degree 3). * * @tparam T Output type * @tparam U Input type (deduced from `elements`) diff --git a/barretenberg/cpp/src/barretenberg/protogalaxy/combiner.test.cpp b/barretenberg/cpp/src/barretenberg/protogalaxy/combiner.test.cpp index 72a6fa53233c..4ff7f81cb513 100644 --- a/barretenberg/cpp/src/barretenberg/protogalaxy/combiner.test.cpp +++ b/barretenberg/cpp/src/barretenberg/protogalaxy/combiner.test.cpp @@ -44,6 +44,12 @@ TEST(Protogalaxy, CombinerOn2Instances) auto prover_polynomials = get_sequential_prover_polynomials( /*log_circuit_size=*/1, idx * 128); restrict_to_standard_arithmetic_relation(prover_polynomials); + // This ensures that the combiner accumulator for second instance = 0 + // The value is computed by generating the python script values, computing the resulting accumulator and + // taking the value at index 1 + if (idx == NUM_INSTANCES - 1) { + prover_polynomials.q_c[0] -= 13644570; + } instance->proving_key.polynomials = std::move(prover_polynomials); instance->proving_key.circuit_size = 2; instance_data[idx] = instance; @@ -52,22 +58,22 @@ TEST(Protogalaxy, CombinerOn2Instances) ProverInstances instances{ instance_data }; instances.alphas.fill(bb::Univariate(FF(0))); // focus on the arithmetic relation only auto pow_polynomial = PowPolynomial(std::vector{ 2 }); - auto result = prover.compute_combiner(instances, pow_polynomial); - auto expected_result = Univariate(std::array{ - 87706, - 13644570, - 76451738, - 226257946, - static_cast(500811930), - static_cast(937862426), - static_cast(1575158170), - static_cast(2450447898), - static_cast(3601480346), - static_cast(5066004250), - static_cast(6881768346), - static_cast(9086521370), - }); + auto result = prover.compute_combiner(instances, pow_polynomial); + auto optimised_result = prover.compute_combiner(instances, pow_polynomial); + auto expected_result = Univariate(std::array{ 87706, + 0, + 0x02ee2966, + 0x0b0bd2cc, + 0x00001a98fc32, + 0x000033d5a598, + 0x00005901cefe, + 0x00008c5d7864, + 0x0000d028a1ca, + 0x000126a34b30UL, + 0x0001920d7496UL, + 0x000214a71dfcUL }); EXPECT_EQ(result, expected_result); + EXPECT_EQ(optimised_result, expected_result); } else { std::vector> instance_data(NUM_INSTANCES); ProtoGalaxyProver prover; @@ -130,11 +136,13 @@ TEST(Protogalaxy, CombinerOn2Instances) 0 0 0 0 0 0 0 0 0 6 18 36 60 90 */ auto pow_polynomial = PowPolynomial(std::vector{ 2 }); - auto result = prover.compute_combiner(instances, pow_polynomial); + auto result = prover.compute_combiner(instances, pow_polynomial); + auto optimised_result = prover.compute_combiner(instances, pow_polynomial); auto expected_result = Univariate(std::array{ 0, 0, 12, 36, 72, 120, 180, 252, 336, 432, 540, 660 }); EXPECT_EQ(result, expected_result); + EXPECT_EQ(optimised_result, expected_result); } }; run_test(true); @@ -181,11 +189,13 @@ TEST(Protogalaxy, CombinerOn4Instances) zero_all_selectors(instances[3]->proving_key.polynomials); auto pow_polynomial = PowPolynomial(std::vector{ 2 }); - auto result = prover.compute_combiner(instances, pow_polynomial); + auto result = prover.compute_combiner(instances, pow_polynomial); + auto optimised_result = prover.compute_combiner(instances, pow_polynomial); std::array zeroes; std::fill(zeroes.begin(), zeroes.end(), 0); auto expected_result = Univariate(zeroes); EXPECT_EQ(result, expected_result); + EXPECT_EQ(optimised_result, expected_result); }; run_test(); }; diff --git a/barretenberg/cpp/src/barretenberg/protogalaxy/protogalaxy.test.cpp b/barretenberg/cpp/src/barretenberg/protogalaxy/protogalaxy.test.cpp index 0b51c91f57b4..3148c54cd406 100644 --- a/barretenberg/cpp/src/barretenberg/protogalaxy/protogalaxy.test.cpp +++ b/barretenberg/cpp/src/barretenberg/protogalaxy/protogalaxy.test.cpp @@ -279,6 +279,11 @@ template class ProtoGalaxyTests : public testing::Test { bb::Univariate expected_eta{ { 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21 } }; EXPECT_EQ(instances.relation_parameters.eta, expected_eta); + // Optimised relation parameters are the same, we just don't compute any values for non-used indices when + // deriving values from them + for (size_t i = 0; i < 11; i++) { + EXPECT_EQ(instances.optimised_relation_parameters.eta.evaluations[i], expected_eta.evaluations[i]); + } } /** diff --git a/barretenberg/cpp/src/barretenberg/protogalaxy/protogalaxy_prover.hpp b/barretenberg/cpp/src/barretenberg/protogalaxy/protogalaxy_prover.hpp index d9cffcca9c70..c03af2e5333e 100644 --- a/barretenberg/cpp/src/barretenberg/protogalaxy/protogalaxy_prover.hpp +++ b/barretenberg/cpp/src/barretenberg/protogalaxy/protogalaxy_prover.hpp @@ -49,15 +49,26 @@ template class ProtoGalaxyProver_ { // The length of ExtendedUnivariate is the largest length (==max_relation_degree + 1) of a univariate polynomial // obtained by composing a relation with folded instance + relation parameters . using ExtendedUnivariate = Univariate; + // Same as ExtendedUnivariate, but uses optimised univariates which skip redundant computation in optimistic cases + // (when we know that the evaluation of all relations is 0 on a particular index, for example) + using OptimisedExtendedUnivariate = + Univariate; // Represents the total length of the combiner univariate, obtained by combining the already folded relations with // the folded relation batching challenge. using ExtendedUnivariateWithRandomization = Univariate; using ExtendedUnivariates = typename Flavor::template ProverUnivariates; + using OptimisedExtendedUnivariates = + typename Flavor::template OptimisedProverUnivariates; using TupleOfTuplesOfUnivariates = typename Flavor::template ProtogalaxyTupleOfTuplesOfUnivariates; + using OptimisedTupleOfTuplesOfUnivariates = + typename Flavor::template OptimisedProtogalaxyTupleOfTuplesOfUnivariates; using RelationEvaluations = typename Flavor::TupleOfArraysOfValues; static constexpr size_t NUM_SUBRELATIONS = ProverInstances::NUM_SUBRELATIONS; @@ -209,14 +220,20 @@ template class ProtoGalaxyProver_ { auto prev_level_width = prev_level_coeffs.size(); // we need degree + 1 terms to represent the intermediate polynomials std::vector> level_coeffs(prev_level_width >> 1, std::vector(degree + 1, 0)); - for (size_t node = 0; node < prev_level_width; node += 2) { - auto parent = node >> 1; - std::copy(prev_level_coeffs[node].begin(), prev_level_coeffs[node].end(), level_coeffs[parent].begin()); - for (size_t d = 0; d < degree; d++) { - level_coeffs[parent][d] += prev_level_coeffs[node + 1][d] * betas[level]; - level_coeffs[parent][d + 1] += prev_level_coeffs[node + 1][d] * deltas[level]; - } - } + run_loop_in_parallel( + prev_level_width >> 1, + [&](size_t start, size_t end) { + for (size_t node = start << 1; node < end << 1; node += 2) { + auto parent = node >> 1; + std::copy( + prev_level_coeffs[node].begin(), prev_level_coeffs[node].end(), level_coeffs[parent].begin()); + for (size_t d = 0; d < degree; d++) { + level_coeffs[parent][d] += prev_level_coeffs[node + 1][d] * betas[level]; + level_coeffs[parent][d + 1] += prev_level_coeffs[node + 1][d] * deltas[level]; + } + } + }, + /*no_multhreading_if_less_or_equal=*/8); return construct_coefficients_tree(betas, deltas, level_coeffs, level + 1); } @@ -236,11 +253,16 @@ template class ProtoGalaxyProver_ { { auto width = full_honk_evaluations.size(); std::vector> first_level_coeffs(width >> 1, std::vector(2, 0)); - for (size_t node = 0; node < width; node += 2) { - auto parent = node >> 1; - first_level_coeffs[parent][0] = full_honk_evaluations[node] + full_honk_evaluations[node + 1] * betas[0]; - first_level_coeffs[parent][1] = full_honk_evaluations[node + 1] * deltas[0]; - } + run_loop_in_parallel(width >> 1, [&](size_t start, size_t end) { + // Run loop in parallel can divide the domain in such way that the indices are odd, which we can't tolerate + // here, so first we divide the width by two, enable parallelism and then reconstruct even start and end + for (size_t node = start << 1; node < end << 1; node += 2) { + auto parent = node >> 1; + first_level_coeffs[parent][0] = + full_honk_evaluations[node] + full_honk_evaluations[node + 1] * betas[0]; + first_level_coeffs[parent][1] = full_honk_evaluations[node + 1] * deltas[0]; + } + }); return construct_coefficients_tree(betas, deltas, first_level_coeffs); } @@ -262,26 +284,38 @@ template class ProtoGalaxyProver_ { return Polynomial(coeffs); } + OptimisedTupleOfTuplesOfUnivariates optimised_univariate_accumulators; TupleOfTuplesOfUnivariates univariate_accumulators; /** * @brief Prepare a univariate polynomial for relation execution in one step of the main loop in folded instance * construction. - * @details For a fixed prover polynomial index, extract that polynomial from each instance in Instances. From each - * polynomial, extract the value at row_idx. Use these values to create a univariate polynomial, and then extend - * (i.e., compute additional evaluations at adjacent domain values) as needed. + * @details For a fixed prover polynomial index, extract that polynomial from each instance in Instances. From + *each polynomial, extract the value at row_idx. Use these values to create a univariate polynomial, and then + *extend (i.e., compute additional evaluations at adjacent domain values) as needed. * @todo TODO(https://github.com/AztecProtocol/barretenberg/issues/751) Optimize memory + * + * */ - void extend_univariates(ExtendedUnivariates& extended_univariates, - const ProverInstances& instances, - const size_t row_idx) + + template + void extend_univariates( + std::conditional_t& extended_univariates, + const ProverInstances& instances, + const size_t row_idx) { - auto base_univariates = instances.row_to_univariates(row_idx); + auto base_univariates = instances.template row_to_univariates(row_idx); for (auto [extended_univariate, base_univariate] : zip_view(extended_univariates.get_all(), base_univariates)) { - extended_univariate = base_univariate.template extend_to(); + extended_univariate = base_univariate.template extend_to(); } } + /** + * @brief Add the value of each relation over univariates to an appropriate accumulator + * + * @tparam Parameters relation parameters type + * @tparam relation_idx The index of the relation + */ template void accumulate_relation_univariates(TupleOfTuplesOfUnivariates& univariate_accumulators, const ExtendedUnivariates& extended_univariates, @@ -294,39 +328,73 @@ template class ProtoGalaxyProver_ { // Repeat for the next relation. if constexpr (relation_idx + 1 < Flavor::NUM_RELATIONS) { - accumulate_relation_univariates( - univariate_accumulators, extended_univariates, relation_parameters, scaling_factor); + accumulate_relation_univariates< + + Parameters, + relation_idx + 1>(univariate_accumulators, extended_univariates, relation_parameters, scaling_factor); } } /** - * @brief Compute the combiner polynomial $G$ in the Protogalaxy paper. + * @brief Add the value of each relation over univariates to an appropriate accumulator with index skipping + * optimisation + * + * @tparam Parameters relation parameters type + * @tparam relation_idx The index of the relation + */ + template + void accumulate_relation_univariates(OptimisedTupleOfTuplesOfUnivariates& univariate_accumulators, + const OptimisedExtendedUnivariates& extended_univariates, + const Parameters& relation_parameters, + const FF& scaling_factor) + { + using Relation = std::tuple_element_t; + Relation::accumulate( + std::get(univariate_accumulators), extended_univariates, relation_parameters, scaling_factor); + + // Repeat for the next relation. + if constexpr (relation_idx + 1 < Flavor::NUM_RELATIONS) { + accumulate_relation_univariates< + + Parameters, + relation_idx + 1>(univariate_accumulators, extended_univariates, relation_parameters, scaling_factor); + } + } + /** + * @brief Compute the combiner polynomial $G$ in the Protogalaxy paper * */ + template = true> ExtendedUnivariateWithRandomization compute_combiner(const ProverInstances& instances, PowPolynomial& pow_betas) { - BB_OP_COUNT_TIME(); size_t common_instance_size = instances[0]->proving_key.circuit_size; pow_betas.compute_values(); // Determine number of threads for multithreading. // Note: Multithreading is "on" for every round but we reduce the number of threads from the max available based - // on a specified minimum number of iterations per thread. This eventually leads to the use of a single thread. - // For now we use a power of 2 number of threads simply to ensure the round size is evenly divided. + // on a specified minimum number of iterations per thread. This eventually leads to the use of a + // single thread. For now we use a power of 2 number of threads simply to ensure the round size is evenly + // divided. size_t max_num_threads = get_num_cpus_pow2(); // number of available threads (power of 2) size_t min_iterations_per_thread = 1 << 6; // min number of iterations for which we'll spin up a unique thread size_t desired_num_threads = common_instance_size / min_iterations_per_thread; size_t num_threads = std::min(desired_num_threads, max_num_threads); // fewer than max if justified num_threads = num_threads > 0 ? num_threads : 1; // ensure num threads is >= 1 size_t iterations_per_thread = common_instance_size / num_threads; // actual iterations per thread + + // Univariates are optimised for usual PG, but we need the unoptimised version for tests (it's a version that + // doesn't skip computation), so we need to define types depending on the template instantiation + using ThreadAccumulators = TupleOfTuplesOfUnivariates; + using ExtendedUnivatiatesType = ExtendedUnivariates; + // Construct univariate accumulator containers; one per thread - std::vector thread_univariate_accumulators(num_threads); + std::vector thread_univariate_accumulators(num_threads); for (auto& accum : thread_univariate_accumulators) { // just normal relation lengths Utils::zero_univariates(accum); } // Construct extended univariates containers; one per thread - std::vector extended_univariates; + std::vector extended_univariates; extended_univariates.resize(num_threads); // Accumulate the contribution from each sub-relation @@ -335,14 +403,15 @@ template class ProtoGalaxyProver_ { size_t end = (thread_idx + 1) * iterations_per_thread; for (size_t idx = start; idx < end; idx++) { - // No need to initialise extended_univariates to 0, it's assigned to + extend_univariates(extended_univariates[thread_idx], instances, idx); FF pow_challenge = pow_betas[idx]; - // Accumulate the i-th row's univariate contribution. Note that the relation parameters passed to this - // function have already been folded. Moreover, linear-dependent relations that act over the entire - // execution trace rather than on rows, will not be multiplied by the pow challenge. + // Accumulate the i-th row's univariate contribution. Note that the relation parameters passed to + // this function have already been folded. Moreover, linear-dependent relations that act over the + // entire execution trace rather than on rows, will not be multiplied by the pow challenge. + accumulate_relation_univariates( thread_univariate_accumulators[thread_idx], extended_univariates[thread_idx], @@ -350,19 +419,115 @@ template class ProtoGalaxyProver_ { pow_challenge); } }); - + Utils::zero_univariates(univariate_accumulators); // Accumulate the per-thread univariate accumulators into a single set of accumulators for (auto& accumulators : thread_univariate_accumulators) { Utils::add_nested_tuples(univariate_accumulators, accumulators); } - // Batch the univariate contributions from each sub-relation to obtain the round univariate + return batch_over_relations(univariate_accumulators, instances.alphas); } + /** + * @brief Compute the combiner polynomial $G$ in the Protogalaxy paper using indice skippping optimisation + * + * @todo (https://github.com/AztecProtocol/barretenberg/issues/968) Make combiner tests better + * + */ + template = true> + ExtendedUnivariateWithRandomization compute_combiner(const ProverInstances& instances, PowPolynomial& pow_betas) + { + BB_OP_COUNT_TIME(); + size_t common_instance_size = instances[0]->proving_key.circuit_size; + pow_betas.compute_values(); + // Determine number of threads for multithreading. + // Note: Multithreading is "on" for every round but we reduce the number of threads from the max available based + // on a specified minimum number of iterations per thread. This eventually leads to the use of a + // single thread. For now we use a power of 2 number of threads simply to ensure the round size is evenly + // divided. + size_t max_num_threads = get_num_cpus_pow2(); // number of available threads (power of 2) + size_t min_iterations_per_thread = 1 << 6; // min number of iterations for which we'll spin up a unique thread + size_t desired_num_threads = common_instance_size / min_iterations_per_thread; + size_t num_threads = std::min(desired_num_threads, max_num_threads); // fewer than max if justified + num_threads = num_threads > 0 ? num_threads : 1; // ensure num threads is >= 1 + size_t iterations_per_thread = common_instance_size / num_threads; // actual iterations per thread + + // Univariates are optimised for usual PG, but we need the unoptimised version for tests (it's a version that + // doesn't skip computation), so we need to define types depending on the template instantiation + using ThreadAccumulators = OptimisedTupleOfTuplesOfUnivariates; + using ExtendedUnivatiatesType = OptimisedExtendedUnivariates; + + // Construct univariate accumulator containers; one per thread + std::vector thread_univariate_accumulators(num_threads); + for (auto& accum : thread_univariate_accumulators) { + // just normal relation lengths + Utils::zero_univariates(accum); + } + + // Construct extended univariates containers; one per thread + std::vector extended_univariates; + extended_univariates.resize(num_threads); + + // Accumulate the contribution from each sub-relation + parallel_for(num_threads, [&](size_t thread_idx) { + size_t start = thread_idx * iterations_per_thread; + size_t end = (thread_idx + 1) * iterations_per_thread; + + for (size_t idx = start; idx < end; idx++) { + // No need to initialise extended_univariates to 0, it's assigned to + // Instantiate univariates with skipping to ignore computation in those indices (they are still + // available for skipping relations, but all derived univariate will ignore those evaluations) + extend_univariates( + extended_univariates[thread_idx], instances, idx); + + FF pow_challenge = pow_betas[idx]; + + // Accumulate the i-th row's univariate contribution. Note that the relation parameters passed to + // this function have already been folded. Moreover, linear-dependent relations that act over the + // entire execution trace rather than on rows, will not be multiplied by the pow challenge. + accumulate_relation_univariates( + thread_univariate_accumulators[thread_idx], + extended_univariates[thread_idx], + instances.optimised_relation_parameters, // these parameters have already been folded + pow_challenge); + } + }); + Utils::zero_univariates(optimised_univariate_accumulators); + // Accumulate the per-thread univariate accumulators into a single set of accumulators + for (auto& accumulators : thread_univariate_accumulators) { + Utils::add_nested_tuples(optimised_univariate_accumulators, accumulators); + } + + // Convert from optimised version to non-optimised + deoptimise_univariates(optimised_univariate_accumulators, univariate_accumulators); + // Batch the univariate contributions from each sub-relation to obtain the round univariate + return batch_over_relations(univariate_accumulators, instances.alphas); + } + + /** + * @brief Convert univariates from optimised form to regular + * + * @details We need to convert before we batch relations, since optimised versions don't have enough information to + * extend the univariates to maximum length + * + * @param optimised_univariate_accumulators + * @param new_univariate_accumulators + */ + static void deoptimise_univariates(const OptimisedTupleOfTuplesOfUnivariates& optimised_univariate_accumulators, + TupleOfTuplesOfUnivariates& new_univariate_accumulators + + ) + { + auto deoptimise = [&](auto& element) { + auto& optimised_element = std::get(std::get(optimised_univariate_accumulators)); + element = optimised_element.convert(); + }; + + Utils::template apply_to_tuple_of_tuples<0, 0>(new_univariate_accumulators, deoptimise); + } static ExtendedUnivariateWithRandomization batch_over_relations(TupleOfTuplesOfUnivariates& univariate_accumulators, const CombinedRelationSeparator& alpha) { - // First relation does not get multiplied by a batching challenge auto result = std::get<0>(std::get<0>(univariate_accumulators)) .template extend_to(); @@ -432,7 +597,8 @@ template class ProtoGalaxyProver_ { { size_t param_idx = 0; auto to_fold = instances.relation_parameters.get_to_fold(); - for (auto& folded_parameter : to_fold) { + auto to_fold_optimised = instances.optimised_relation_parameters.get_to_fold(); + for (auto [folded_parameter, optimised_folded_parameter] : zip_view(to_fold, to_fold_optimised)) { Univariate tmp(0); size_t instance_idx = 0; for (auto& instance : instances) { @@ -440,6 +606,8 @@ template class ProtoGalaxyProver_ { instance_idx++; } folded_parameter = tmp.template extend_to(); + optimised_folded_parameter = + tmp.template extend_to(); param_idx++; } } diff --git a/barretenberg/cpp/src/barretenberg/relations/nested_containers.hpp b/barretenberg/cpp/src/barretenberg/relations/nested_containers.hpp index 46f2d2463035..36a522eb1618 100644 --- a/barretenberg/cpp/src/barretenberg/relations/nested_containers.hpp +++ b/barretenberg/cpp/src/barretenberg/relations/nested_containers.hpp @@ -10,30 +10,42 @@ namespace bb { * * @details Credit: https://stackoverflow.com/a/60440611 */ -template