diff --git a/yarn-project/bb-prover/src/bb/execute.ts b/yarn-project/bb-prover/src/bb/execute.ts index 11ef630a42a..e5159fb55e7 100644 --- a/yarn-project/bb-prover/src/bb/execute.ts +++ b/yarn-project/bb-prover/src/bb/execute.ts @@ -43,6 +43,7 @@ export type BBSuccess = { export type BBFailure = { status: BB_RESULT.FAILURE; reason: string; + retry?: boolean; }; export type BBResult = BBSuccess | BBFailure; @@ -175,6 +176,7 @@ export async function generateKeyForNoirCircuit( return { status: BB_RESULT.FAILURE, reason: `Failed to generate key. Exit code: ${result.exitCode}. Signal ${result.signal}.`, + retry: !!result.signal, }; } catch (error) { return { status: BB_RESULT.FAILURE, reason: `${error}` }; @@ -245,6 +247,7 @@ export async function executeBbClientIvcProof( return { status: BB_RESULT.FAILURE, reason: `Failed to generate proof. Exit code ${result.exitCode}. Signal ${result.signal}.`, + retry: !!result.signal, }; } catch (error) { return { status: BB_RESULT.FAILURE, reason: `${error}` }; @@ -324,6 +327,7 @@ export async function computeVerificationKey( return { status: BB_RESULT.FAILURE, reason: `Failed to write VK. Exit code ${result.exitCode}. Signal ${result.signal}.`, + retry: !!result.signal, }; } catch (error) { return { status: BB_RESULT.FAILURE, reason: `${error}` }; @@ -396,6 +400,7 @@ export async function generateProof( return { status: BB_RESULT.FAILURE, reason: `Failed to generate proof. Exit code ${result.exitCode}. Signal ${result.signal}.`, + retry: !!result.signal, }; } catch (error) { return { status: BB_RESULT.FAILURE, reason: `${error}` }; @@ -470,6 +475,7 @@ export async function generateTubeProof( return { status: BB_RESULT.FAILURE, reason: `Failed to generate proof. Exit code ${result.exitCode}. Signal ${result.signal}.`, + retry: !!result.signal, }; } catch (error) { return { status: BB_RESULT.FAILURE, reason: `${error}` }; @@ -573,6 +579,7 @@ export async function generateAvmProof( return { status: BB_RESULT.FAILURE, reason: `Failed to generate proof. Exit code ${result.exitCode}. Signal ${result.signal}.`, + retry: !!result.signal, }; } catch (error) { return { status: BB_RESULT.FAILURE, reason: `${error}` }; @@ -648,6 +655,7 @@ export async function verifyClientIvcProof( return { status: BB_RESULT.FAILURE, reason: `Failed to verify proof. Exit code ${result.exitCode}. Signal ${result.signal}.`, + retry: !!result.signal, }; } catch (error) { return { status: BB_RESULT.FAILURE, reason: `${error}` }; @@ -690,6 +698,7 @@ async function verifyProofInternal( return { status: BB_RESULT.FAILURE, reason: `Failed to verify proof. Exit code ${result.exitCode}. Signal ${result.signal}.`, + retry: !!result.signal, }; } catch (error) { return { status: BB_RESULT.FAILURE, reason: `${error}` }; @@ -730,6 +739,7 @@ export async function writeVkAsFields( return { status: BB_RESULT.FAILURE, reason: `Failed to create vk as fields. Exit code ${result.exitCode}. Signal ${result.signal}.`, + retry: !!result.signal, }; } catch (error) { return { status: BB_RESULT.FAILURE, reason: `${error}` }; @@ -772,6 +782,7 @@ export async function writeProofAsFields( return { status: BB_RESULT.FAILURE, reason: `Failed to create proof as fields. Exit code ${result.exitCode}. Signal ${result.signal}.`, + retry: !!result.signal, }; } catch (error) { return { status: BB_RESULT.FAILURE, reason: `${error}` }; @@ -813,6 +824,7 @@ export async function generateContractForVerificationKey( return { status: BB_RESULT.FAILURE, reason: `Failed to write verifier contract. Exit code ${result.exitCode}. Signal ${result.signal}.`, + retry: !!result.signal, }; } catch (error) { return { status: BB_RESULT.FAILURE, reason: `${error}` }; diff --git a/yarn-project/bb-prover/src/prover/bb_prover.ts b/yarn-project/bb-prover/src/prover/bb_prover.ts index a8475793cba..d4db2fbd376 100644 --- a/yarn-project/bb-prover/src/prover/bb_prover.ts +++ b/yarn-project/bb-prover/src/prover/bb_prover.ts @@ -1,6 +1,7 @@ /* eslint-disable require-await */ import { type ProofAndVerificationKey, + ProvingError, type PublicInputsAndRecursiveProof, type ServerCircuitProver, makeProofAndVerificationKey, @@ -477,7 +478,7 @@ export class BBNativeRollupProver implements ServerCircuitProver { if (provingResult.status === BB_RESULT.FAILURE) { logger.error(`Failed to generate proof for ${circuitType}: ${provingResult.reason}`); - throw new Error(provingResult.reason); + throw new ProvingError(provingResult.reason, provingResult, provingResult.retry); } // Ensure our vk cache is up to date @@ -538,7 +539,7 @@ export class BBNativeRollupProver implements ServerCircuitProver { if (provingResult.status === BB_RESULT.FAILURE) { logger.error(`Failed to generate AVM proof for ${input.functionName}: ${provingResult.reason}`); - throw new Error(provingResult.reason); + throw new ProvingError(provingResult.reason, provingResult, provingResult.retry); } return provingResult; @@ -555,7 +556,7 @@ export class BBNativeRollupProver implements ServerCircuitProver { if (provingResult.status === BB_RESULT.FAILURE) { logger.error(`Failed to generate proof for tube proof: ${provingResult.reason}`); - throw new Error(provingResult.reason); + throw new ProvingError(provingResult.reason, provingResult, provingResult.retry); } return provingResult; } @@ -724,7 +725,7 @@ export class BBNativeRollupProver implements ServerCircuitProver { if (result.status === BB_RESULT.FAILURE) { const errorMessage = `Failed to verify proof from key!`; - throw new Error(errorMessage); + throw new ProvingError(errorMessage, result, result.retry); } logger.info(`Successfully verified proof from key in ${result.durationMs} ms`); @@ -785,7 +786,7 @@ export class BBNativeRollupProver implements ServerCircuitProver { if (result.status === BB_RESULT.FAILURE) { const errorMessage = `Failed to convert ${circuit} proof to fields, ${result.reason}`; - throw new Error(errorMessage); + throw new ProvingError(errorMessage, result, result.retry); } const proofString = await fs.readFile(path.join(bbWorkingDirectory, PROOF_FIELDS_FILENAME), { @@ -825,7 +826,11 @@ export class BBNativeRollupProver implements ServerCircuitProver { logger.debug, ).then(result => { if (result.status === BB_RESULT.FAILURE) { - throw new Error(`Failed to generate verification key for ${circuitType}, ${result.reason}`); + throw new ProvingError( + `Failed to generate verification key for ${circuitType}, ${result.reason}`, + result, + result.retry, + ); } return extractVkData(result.vkPath!); }); diff --git a/yarn-project/circuit-types/src/index.ts b/yarn-project/circuit-types/src/index.ts index 86eade75e74..00f71f992c1 100644 --- a/yarn-project/circuit-types/src/index.ts +++ b/yarn-project/circuit-types/src/index.ts @@ -22,3 +22,4 @@ export * from './simulation_error.js'; export * from './tx/index.js'; export * from './tx_effect.js'; export * from './tx_execution_request.js'; +export * from './proving_error.js'; diff --git a/yarn-project/circuit-types/src/proving_error.ts b/yarn-project/circuit-types/src/proving_error.ts new file mode 100644 index 00000000000..7207270958d --- /dev/null +++ b/yarn-project/circuit-types/src/proving_error.ts @@ -0,0 +1,18 @@ +/** + * An error thrown when generating a proof fails. + */ +export class ProvingError extends Error { + public static readonly NAME = 'ProvingError'; + + /** + * Creates a new instance + * @param message - The error message. + * @param cause - The cause of the error. + * @param retry - Whether the proof should be retried. + */ + constructor(message: string, cause?: unknown, public readonly retry: boolean = false) { + super(message); + this.name = ProvingError.NAME; + this.cause = cause; + } +} diff --git a/yarn-project/prover-client/src/proving_broker/proving_agent.test.ts b/yarn-project/prover-client/src/proving_broker/proving_agent.test.ts new file mode 100644 index 00000000000..9a2c7db1da9 --- /dev/null +++ b/yarn-project/prover-client/src/proving_broker/proving_agent.test.ts @@ -0,0 +1,226 @@ +import { + ProvingError, + ProvingRequestType, + type PublicInputsAndRecursiveProof, + type V2ProvingJob, + type V2ProvingJobId, + makePublicInputsAndRecursiveProof, +} from '@aztec/circuit-types'; +import { + type ParityPublicInputs, + RECURSIVE_PROOF_LENGTH, + VerificationKeyData, + makeRecursiveProof, +} from '@aztec/circuits.js'; +import { makeBaseParityInputs, makeParityPublicInputs } from '@aztec/circuits.js/testing'; +import { randomBytes } from '@aztec/foundation/crypto'; +import { AbortError } from '@aztec/foundation/error'; +import { promiseWithResolvers } from '@aztec/foundation/promise'; + +import { jest } from '@jest/globals'; + +import { MockProver } from '../test/mock_prover.js'; +import { ProvingAgent } from './proving_agent.js'; +import { type ProvingJobConsumer } from './proving_broker_interface.js'; + +describe('ProvingAgent', () => { + let prover: MockProver; + let jobSource: jest.Mocked; + let agent: ProvingAgent; + const agentPollIntervalMs = 1000; + + beforeEach(() => { + jest.useFakeTimers(); + + prover = new MockProver(); + jobSource = { + getProvingJob: jest.fn(), + reportProvingJobProgress: jest.fn(), + reportProvingJobError: jest.fn(), + reportProvingJobSuccess: jest.fn(), + }; + agent = new ProvingAgent(jobSource, prover, [ProvingRequestType.BASE_PARITY]); + }); + + afterEach(async () => { + await agent.stop(); + }); + + it('polls for jobs passing the permitted list of proofs', () => { + agent.start(); + expect(jobSource.getProvingJob).toHaveBeenCalledWith({ allowList: [ProvingRequestType.BASE_PARITY] }); + }); + + it('only takes a single job from the source at a time', async () => { + expect(jobSource.getProvingJob).not.toHaveBeenCalled(); + + // simulate the proof taking a long time + const { promise, resolve } = + promiseWithResolvers>(); + jest.spyOn(prover, 'getBaseParityProof').mockReturnValueOnce(promise); + + const jobResponse = makeBaseParityJob(); + jobSource.getProvingJob.mockResolvedValueOnce(jobResponse); + agent.start(); + + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + expect(jobSource.getProvingJob).toHaveBeenCalledTimes(1); + + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + expect(jobSource.getProvingJob).toHaveBeenCalledTimes(1); + + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + expect(jobSource.getProvingJob).toHaveBeenCalledTimes(1); + + // let's resolve the proof + const result = makePublicInputsAndRecursiveProof( + makeParityPublicInputs(), + makeRecursiveProof(RECURSIVE_PROOF_LENGTH), + VerificationKeyData.makeFakeHonk(), + ); + resolve(result); + + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + expect(jobSource.getProvingJob).toHaveBeenCalledTimes(2); + }); + + it('reports success to the job source', async () => { + const jobResponse = makeBaseParityJob(); + const result = makeBaseParityResult(); + jest.spyOn(prover, 'getBaseParityProof').mockResolvedValueOnce(result.value); + + jobSource.getProvingJob.mockResolvedValueOnce(jobResponse); + agent.start(); + + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + expect(jobSource.reportProvingJobSuccess).toHaveBeenCalledWith(jobResponse.job.id, result); + }); + + it('reports errors to the job source', async () => { + const jobResponse = makeBaseParityJob(); + jest.spyOn(prover, 'getBaseParityProof').mockRejectedValueOnce(new Error('test error')); + + jobSource.getProvingJob.mockResolvedValueOnce(jobResponse); + agent.start(); + + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(jobResponse.job.id, new Error('test error'), false); + }); + + it('sets the retry flag on when reporting an error', async () => { + const jobResponse = makeBaseParityJob(); + const err = new ProvingError('test error', undefined, true); + jest.spyOn(prover, 'getBaseParityProof').mockRejectedValueOnce(err); + + jobSource.getProvingJob.mockResolvedValueOnce(jobResponse); + agent.start(); + + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(jobResponse.job.id, err, true); + }); + + it('reports jobs in progress to the job source', async () => { + const jobResponse = makeBaseParityJob(); + const { promise, resolve } = + promiseWithResolvers>(); + jest.spyOn(prover, 'getBaseParityProof').mockReturnValueOnce(promise); + + jobSource.getProvingJob.mockResolvedValueOnce(jobResponse); + agent.start(); + + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + expect(jobSource.reportProvingJobProgress).toHaveBeenCalledWith(jobResponse.job.id, jobResponse.time, { + allowList: [ProvingRequestType.BASE_PARITY], + }); + + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + expect(jobSource.reportProvingJobProgress).toHaveBeenCalledWith(jobResponse.job.id, jobResponse.time, { + allowList: [ProvingRequestType.BASE_PARITY], + }); + + resolve(makeBaseParityResult().value); + }); + + it('abandons jobs if told so by the source', async () => { + const firstJobResponse = makeBaseParityJob(); + let firstProofAborted = false; + const firstProof = + promiseWithResolvers>(); + + // simulate a long running proving job that can be aborted + jest.spyOn(prover, 'getBaseParityProof').mockImplementationOnce((_, signal) => { + signal?.addEventListener('abort', () => { + firstProof.reject(new AbortError('test abort')); + firstProofAborted = true; + }); + return firstProof.promise; + }); + + jobSource.getProvingJob.mockResolvedValueOnce(firstJobResponse); + agent.start(); + + // now the agent should be happily proving and reporting progress + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + expect(jobSource.reportProvingJobProgress).toHaveBeenCalledTimes(1); + expect(jobSource.reportProvingJobProgress).toHaveBeenCalledWith(firstJobResponse.job.id, firstJobResponse.time, { + allowList: [ProvingRequestType.BASE_PARITY], + }); + + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + expect(jobSource.reportProvingJobProgress).toHaveBeenCalledTimes(2); + + // now let's simulate the job source cancelling the job and giving the agent something else to do + // this should cause the agent to abort the current job and start the new one + const secondJobResponse = makeBaseParityJob(); + jobSource.reportProvingJobProgress.mockResolvedValueOnce(secondJobResponse); + + const secondProof = + promiseWithResolvers>(); + jest.spyOn(prover, 'getBaseParityProof').mockReturnValueOnce(secondProof.promise); + + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + expect(jobSource.reportProvingJobProgress).toHaveBeenCalledTimes(3); + expect(jobSource.reportProvingJobProgress).toHaveBeenLastCalledWith( + firstJobResponse.job.id, + firstJobResponse.time, + { + allowList: [ProvingRequestType.BASE_PARITY], + }, + ); + expect(firstProofAborted).toBe(true); + + // agent should have switched now + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + expect(jobSource.reportProvingJobProgress).toHaveBeenCalledTimes(4); + expect(jobSource.reportProvingJobProgress).toHaveBeenLastCalledWith( + secondJobResponse.job.id, + secondJobResponse.time, + { + allowList: [ProvingRequestType.BASE_PARITY], + }, + ); + + secondProof.resolve(makeBaseParityResult().value); + }); + + function makeBaseParityJob(): { job: V2ProvingJob; time: number } { + const time = jest.now(); + const job: V2ProvingJob = { + id: randomBytes(8).toString('hex') as V2ProvingJobId, + blockNumber: 1, + type: ProvingRequestType.BASE_PARITY, + inputs: makeBaseParityInputs(), + }; + + return { job, time }; + } + + function makeBaseParityResult() { + const value = makePublicInputsAndRecursiveProof( + makeParityPublicInputs(), + makeRecursiveProof(RECURSIVE_PROOF_LENGTH), + VerificationKeyData.makeFakeHonk(), + ); + return { type: ProvingRequestType.BASE_PARITY, value }; + } +}); diff --git a/yarn-project/prover-client/src/proving_broker/proving_agent.ts b/yarn-project/prover-client/src/proving_broker/proving_agent.ts new file mode 100644 index 00000000000..5ee86900e0d --- /dev/null +++ b/yarn-project/prover-client/src/proving_broker/proving_agent.ts @@ -0,0 +1,90 @@ +import { + ProvingError, + type ProvingRequestType, + type ServerCircuitProver, + type V2ProvingJob, +} from '@aztec/circuit-types'; +import { createDebugLogger } from '@aztec/foundation/log'; +import { RunningPromise } from '@aztec/foundation/running-promise'; + +import { type ProvingJobConsumer } from './proving_broker_interface.js'; +import { ProvingJobController, ProvingJobStatus } from './proving_job_controller.js'; + +/** + * A helper class that encapsulates a circuit prover and connects it to a job source. + */ +export class ProvingAgent { + private currentJobController?: ProvingJobController; + private runningPromise: RunningPromise; + + constructor( + /** The source of proving jobs */ + private jobSource: ProvingJobConsumer, + /** The prover implementation to defer jobs to */ + private circuitProver: ServerCircuitProver, + /** Optional list of allowed proof types to build */ + private proofAllowList?: Array, + /** How long to wait between jobs */ + private pollIntervalMs = 1000, + private log = createDebugLogger('aztec:proving-broker:proving-agent'), + ) { + this.runningPromise = new RunningPromise(this.safeWork, this.pollIntervalMs); + } + + public setCircuitProver(circuitProver: ServerCircuitProver): void { + this.circuitProver = circuitProver; + } + + public isRunning(): boolean { + return this.runningPromise?.isRunning() ?? false; + } + + public start(): void { + this.runningPromise.start(); + } + + public async stop(): Promise { + this.currentJobController?.abort(); + await this.runningPromise.stop(); + } + + private safeWork = async () => { + try { + // every tick we need to + // (1) either do a heartbeat, telling the broker that we're working + // (2) get a new job + // If during (1) the broker returns a new job that means we can cancel the current job and start the new one + let maybeJob: { job: V2ProvingJob; time: number } | undefined; + if (this.currentJobController?.getStatus() === ProvingJobStatus.PROVING) { + maybeJob = await this.jobSource.reportProvingJobProgress( + this.currentJobController.getJobId(), + this.currentJobController.getStartedAt(), + { allowList: this.proofAllowList }, + ); + } else { + maybeJob = await this.jobSource.getProvingJob({ allowList: this.proofAllowList }); + } + + if (!maybeJob) { + return; + } + + if (this.currentJobController?.getStatus() === ProvingJobStatus.PROVING) { + this.currentJobController?.abort(); + } + + const { job, time } = maybeJob; + this.currentJobController = new ProvingJobController(job, time, this.circuitProver, (err, result) => { + if (err) { + const retry = err.name === ProvingError.NAME ? (err as ProvingError).retry : false; + return this.jobSource.reportProvingJobError(job.id, err, retry); + } else if (result) { + return this.jobSource.reportProvingJobSuccess(job.id, result); + } + }); + this.currentJobController.start(); + } catch (err) { + this.log.error(`Error in ProvingAgent: ${String(err)}`); + } + }; +} diff --git a/yarn-project/prover-client/src/proving_broker/proving_job_controller.test.ts b/yarn-project/prover-client/src/proving_broker/proving_job_controller.test.ts new file mode 100644 index 00000000000..724d1d4606f --- /dev/null +++ b/yarn-project/prover-client/src/proving_broker/proving_job_controller.test.ts @@ -0,0 +1,91 @@ +import { ProvingRequestType, type V2ProvingJobId, makePublicInputsAndRecursiveProof } from '@aztec/circuit-types'; +import { RECURSIVE_PROOF_LENGTH, VerificationKeyData, makeRecursiveProof } from '@aztec/circuits.js'; +import { makeBaseParityInputs, makeParityPublicInputs } from '@aztec/circuits.js/testing'; +import { sleep } from '@aztec/foundation/sleep'; + +import { jest } from '@jest/globals'; + +import { MockProver } from '../test/mock_prover.js'; +import { ProvingJobController, ProvingJobStatus } from './proving_job_controller.js'; + +describe('ProvingJobController', () => { + let prover: MockProver; + let onComplete: jest.Mock; + let controller: ProvingJobController; + + beforeEach(() => { + prover = new MockProver(); + onComplete = jest.fn(); + controller = new ProvingJobController( + { + type: ProvingRequestType.BASE_PARITY, + blockNumber: 1, + id: '1' as V2ProvingJobId, + inputs: makeBaseParityInputs(), + }, + 0, + prover, + onComplete, + ); + }); + + it('reports IDLE status initially', () => { + expect(controller.getStatus()).toBe(ProvingJobStatus.IDLE); + }); + + it('reports PROVING status while busy', () => { + controller.start(); + expect(controller.getStatus()).toBe(ProvingJobStatus.PROVING); + }); + + it('reports DONE status after job is done', async () => { + controller.start(); + await sleep(1); // give promises a chance to complete + expect(controller.getStatus()).toBe(ProvingJobStatus.DONE); + }); + + it('calls onComplete with the proof', async () => { + const resp = makePublicInputsAndRecursiveProof( + makeParityPublicInputs(), + makeRecursiveProof(RECURSIVE_PROOF_LENGTH), + VerificationKeyData.makeFakeHonk(), + ); + jest.spyOn(prover, 'getBaseParityProof').mockResolvedValueOnce(resp); + + controller.start(); + await sleep(1); // give promises a chance to complete + expect(onComplete).toHaveBeenCalledWith(undefined, { + type: ProvingRequestType.BASE_PARITY, + value: resp, + }); + }); + + it('calls onComplete with the error', async () => { + const err = new Error('test error'); + jest.spyOn(prover, 'getBaseParityProof').mockRejectedValueOnce(err); + + controller.start(); + await sleep(1); + expect(onComplete).toHaveBeenCalledWith(err, undefined); + }); + + it('does not crash if onComplete throws', async () => { + const err = new Error('test error'); + onComplete.mockImplementationOnce(() => { + throw err; + }); + + controller.start(); + await sleep(1); + expect(onComplete).toHaveBeenCalled(); + }); + + it('does not crash if onComplete rejects', async () => { + const err = new Error('test error'); + onComplete.mockRejectedValueOnce(err); + + controller.start(); + await sleep(1); + expect(onComplete).toHaveBeenCalled(); + }); +}); diff --git a/yarn-project/prover-client/src/proving_broker/proving_job_controller.ts b/yarn-project/prover-client/src/proving_broker/proving_job_controller.ts new file mode 100644 index 00000000000..53d18b476a0 --- /dev/null +++ b/yarn-project/prover-client/src/proving_broker/proving_job_controller.ts @@ -0,0 +1,148 @@ +import { + ProvingRequestType, + type ServerCircuitProver, + type V2ProofOutput, + type V2ProvingJob, + type V2ProvingJobId, +} from '@aztec/circuit-types'; + +export enum ProvingJobStatus { + IDLE = 'idle', + PROVING = 'proving', + DONE = 'done', +} + +type ProvingJobCompletionCallback = ( + error: Error | undefined, + result: V2ProofOutput | undefined, +) => void | Promise; + +export class ProvingJobController { + private status: ProvingJobStatus = ProvingJobStatus.IDLE; + private promise?: Promise; + private abortController = new AbortController(); + + constructor( + private job: V2ProvingJob, + private startedAt: number, + private circuitProver: ServerCircuitProver, + private onComplete: ProvingJobCompletionCallback, + ) {} + + public start(): void { + if (this.status !== ProvingJobStatus.IDLE) { + return; + } + + this.status = ProvingJobStatus.PROVING; + this.promise = this.generateProof() + .then( + result => { + this.status = ProvingJobStatus.DONE; + return this.onComplete(undefined, result); + }, + error => { + this.status = ProvingJobStatus.DONE; + if (error.name === 'AbortError') { + // Ignore abort errors + return; + } + return this.onComplete(error, undefined); + }, + ) + .catch(_ => { + // ignore completion errors + }); + } + + public getStatus(): ProvingJobStatus { + return this.status; + } + + public abort(): void { + if (this.status !== ProvingJobStatus.PROVING) { + return; + } + + this.abortController.abort(); + } + + public getJobId(): V2ProvingJobId { + return this.job.id; + } + + public getStartedAt(): number { + return this.startedAt; + } + + private async generateProof(): Promise { + const { type, inputs } = this.job; + const signal = this.abortController.signal; + switch (type) { + case ProvingRequestType.PUBLIC_VM: { + const value = await this.circuitProver.getAvmProof(inputs, signal); + return { type, value }; + } + + case ProvingRequestType.PRIVATE_BASE_ROLLUP: { + const value = await this.circuitProver.getPrivateBaseRollupProof(inputs, signal); + return { type, value }; + } + + case ProvingRequestType.PUBLIC_BASE_ROLLUP: { + const value = await this.circuitProver.getPublicBaseRollupProof(inputs, signal); + return { type, value }; + } + + case ProvingRequestType.MERGE_ROLLUP: { + const value = await this.circuitProver.getMergeRollupProof(inputs, signal); + return { type, value }; + } + + case ProvingRequestType.EMPTY_BLOCK_ROOT_ROLLUP: { + const value = await this.circuitProver.getEmptyBlockRootRollupProof(inputs, signal); + return { type, value }; + } + + case ProvingRequestType.BLOCK_ROOT_ROLLUP: { + const value = await this.circuitProver.getBlockRootRollupProof(inputs, signal); + return { type, value }; + } + + case ProvingRequestType.BLOCK_MERGE_ROLLUP: { + const value = await this.circuitProver.getBlockMergeRollupProof(inputs, signal); + return { type, value }; + } + + case ProvingRequestType.ROOT_ROLLUP: { + const value = await this.circuitProver.getRootRollupProof(inputs, signal); + return { type, value }; + } + + case ProvingRequestType.BASE_PARITY: { + const value = await this.circuitProver.getBaseParityProof(inputs, signal); + return { type, value }; + } + + case ProvingRequestType.ROOT_PARITY: { + const value = await this.circuitProver.getRootParityProof(inputs, signal); + return { type, value }; + } + + case ProvingRequestType.PRIVATE_KERNEL_EMPTY: { + const value = await this.circuitProver.getEmptyPrivateKernelProof(inputs, signal); + return { type, value }; + } + + case ProvingRequestType.TUBE_PROOF: { + const value = await this.circuitProver.getTubeProof(inputs, signal); + return { type, value }; + } + + default: { + const _exhaustive: never = type; + return Promise.reject(new Error(`Invalid proof request type: ${type}`)); + } + } + } +} diff --git a/yarn-project/prover-client/src/test/mock_prover.ts b/yarn-project/prover-client/src/test/mock_prover.ts index d4f253fa868..118ff214e14 100644 --- a/yarn-project/prover-client/src/test/mock_prover.ts +++ b/yarn-project/prover-client/src/test/mock_prover.ts @@ -8,11 +8,22 @@ import { import { AVM_PROOF_LENGTH_IN_FIELDS, AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS, + type AvmCircuitInputs, type BaseOrMergeRollupPublicInputs, + type BaseParityInputs, + type BlockMergeRollupInputs, type BlockRootOrBlockMergePublicInputs, + type BlockRootRollupInputs, + type EmptyBlockRootRollupInputs, type KernelCircuitPublicInputs, + type MergeRollupInputs, NESTED_RECURSIVE_PROOF_LENGTH, + type PrivateBaseRollupInputs, + type PrivateKernelEmptyInputData, + type PublicBaseRollupInputs, RECURSIVE_PROOF_LENGTH, + type RootParityInputs, + type RootRollupInputs, type RootRollupPublicInputs, TUBE_PROOF_LENGTH, VerificationKeyData, @@ -30,7 +41,7 @@ import { export class MockProver implements ServerCircuitProver { constructor() {} - getAvmProof() { + getAvmProof(_inputs: AvmCircuitInputs, _signal?: AbortSignal, _epochNumber?: number) { return Promise.resolve( makeProofAndVerificationKey( makeEmptyRecursiveProof(AVM_PROOF_LENGTH_IN_FIELDS), @@ -39,7 +50,7 @@ export class MockProver implements ServerCircuitProver { ); } - getBaseParityProof() { + getBaseParityProof(_inputs: BaseParityInputs, _signal?: AbortSignal, _epochNumber?: number) { return Promise.resolve( makePublicInputsAndRecursiveProof( makeParityPublicInputs(), @@ -49,7 +60,7 @@ export class MockProver implements ServerCircuitProver { ); } - getRootParityProof() { + getRootParityProof(_inputs: RootParityInputs, _signal?: AbortSignal, _epochNumber?: number) { return Promise.resolve( makePublicInputsAndRecursiveProof( makeParityPublicInputs(), @@ -59,7 +70,11 @@ export class MockProver implements ServerCircuitProver { ); } - getPrivateBaseRollupProof(): Promise> { + getPrivateBaseRollupProof( + _baseRollupInput: PrivateBaseRollupInputs, + _signal?: AbortSignal, + _epochNumber?: number, + ): Promise> { return Promise.resolve( makePublicInputsAndRecursiveProof( makeBaseOrMergeRollupPublicInputs(), @@ -69,7 +84,11 @@ export class MockProver implements ServerCircuitProver { ); } - getPublicBaseRollupProof(): Promise> { + getPublicBaseRollupProof( + _inputs: PublicBaseRollupInputs, + _signal?: AbortSignal, + _epochNumber?: number, + ): Promise> { return Promise.resolve( makePublicInputsAndRecursiveProof( makeBaseOrMergeRollupPublicInputs(), @@ -79,7 +98,11 @@ export class MockProver implements ServerCircuitProver { ); } - getMergeRollupProof(): Promise> { + getMergeRollupProof( + _input: MergeRollupInputs, + _signal?: AbortSignal, + _epochNumber?: number, + ): Promise> { return Promise.resolve( makePublicInputsAndRecursiveProof( makeBaseOrMergeRollupPublicInputs(), @@ -89,7 +112,7 @@ export class MockProver implements ServerCircuitProver { ); } - getBlockMergeRollupProof() { + getBlockMergeRollupProof(_input: BlockMergeRollupInputs, _signal?: AbortSignal, _epochNumber?: number) { return Promise.resolve( makePublicInputsAndRecursiveProof( makeBlockRootOrBlockMergeRollupPublicInputs(), @@ -99,7 +122,11 @@ export class MockProver implements ServerCircuitProver { ); } - getEmptyBlockRootRollupProof(): Promise> { + getEmptyBlockRootRollupProof( + _input: EmptyBlockRootRollupInputs, + _signal?: AbortSignal, + _epochNumber?: number, + ): Promise> { return Promise.resolve( makePublicInputsAndRecursiveProof( makeBlockRootOrBlockMergeRollupPublicInputs(), @@ -109,7 +136,11 @@ export class MockProver implements ServerCircuitProver { ); } - getBlockRootRollupProof(): Promise> { + getBlockRootRollupProof( + _input: BlockRootRollupInputs, + _signal?: AbortSignal, + _epochNumber?: number, + ): Promise> { return Promise.resolve( makePublicInputsAndRecursiveProof( makeBlockRootOrBlockMergeRollupPublicInputs(), @@ -119,7 +150,11 @@ export class MockProver implements ServerCircuitProver { ); } - getEmptyPrivateKernelProof(): Promise> { + getEmptyPrivateKernelProof( + _inputs: PrivateKernelEmptyInputData, + _signal?: AbortSignal, + _epochNumber?: number, + ): Promise> { return Promise.resolve( makePublicInputsAndRecursiveProof( makeKernelCircuitPublicInputs(), @@ -129,7 +164,11 @@ export class MockProver implements ServerCircuitProver { ); } - getRootRollupProof(): Promise> { + getRootRollupProof( + _input: RootRollupInputs, + _signal?: AbortSignal, + _epochNumber?: number, + ): Promise> { return Promise.resolve( makePublicInputsAndRecursiveProof( makeRootRollupPublicInputs(),