Skip to content

Commit

Permalink
Merge e55467d into 98ba747
Browse files Browse the repository at this point in the history
  • Loading branch information
alexghr authored Dec 11, 2024
2 parents 98ba747 + e55467d commit 32ace1c
Show file tree
Hide file tree
Showing 11 changed files with 483 additions and 174 deletions.
10 changes: 8 additions & 2 deletions yarn-project/circuit-types/src/interfaces/prover-broker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,16 @@ export interface ProvingJobProducer {
enqueueProvingJob(job: ProvingJob): Promise<void>;

/**
* Cancels a proving job and clears all of its
* Cancels a proving job.
* @param id - The ID of the job to cancel
*/
removeAndCancelProvingJob(id: ProvingJobId): Promise<void>;
cancelProvingJob(id: ProvingJobId): Promise<void>;

/**
* Cleans up after a job has completed. Throws if the job is in-progress
* @param id - The ID of the job to cancel
*/
cleanUpProvingJobState(id: ProvingJobId): Promise<void>;

/**
* Returns the current status fof the proving job
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class MockProvingJobSource implements ProvingJobSource {
id: 'a-job-id',
type: ProvingRequestType.PRIVATE_BASE_ROLLUP,
inputsUri: 'inputs-uri' as ProofUri,
epochNumber: 1,
});
}
heartbeat(jobId: string): Promise<void> {
Expand Down
2 changes: 1 addition & 1 deletion yarn-project/circuit-types/src/interfaces/proving-job.ts
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ export type ProvingJobId = z.infer<typeof ProvingJobId>;
export const ProvingJob = z.object({
id: ProvingJobId,
type: z.nativeEnum(ProvingRequestType),
blockNumber: z.number().optional(),
epochNumber: z.number(),
inputsUri: ProofUri,
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ export class MemoryProvingQueue implements ServerCircuitProver, ProvingJobSource
id: job.id,
type: job.type,
inputsUri: job.inputsUri,
epochNumber: job.epochNumber,
};
} catch (err) {
if (err instanceof TimeoutError) {
Expand Down Expand Up @@ -244,7 +245,7 @@ export class MemoryProvingQueue implements ServerCircuitProver, ProvingJobSource
reject,
attempts: 1,
heartbeat: 0,
epochNumber,
epochNumber: epochNumber ?? 0,
};

if (signal) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ describe('CachingBrokerFacade', () => {
broker = mock<ProvingJobProducer>({
enqueueProvingJob: jest.fn<any>(),
getProvingJobStatus: jest.fn<any>(),
removeAndCancelProvingJob: jest.fn<any>(),
cancelProvingJob: jest.fn<any>(),
cleanUpProvingJobState: jest.fn<any>(),
waitForJobToSettle: jest.fn<any>(),
});
cache = new InMemoryProverCache();
Expand Down Expand Up @@ -101,4 +102,55 @@ describe('CachingBrokerFacade', () => {
await expect(facade.getBaseParityProof(inputs)).resolves.toEqual(result);
expect(broker.enqueueProvingJob).toHaveBeenCalledTimes(1); // job was only ever enqueued once
});

it('clears broker state after a job resolves', async () => {
const { promise, resolve } = promiseWithResolvers<any>();
broker.enqueueProvingJob.mockResolvedValue(Promise.resolve());
broker.waitForJobToSettle.mockResolvedValue(promise);

const inputs = makeBaseParityInputs();
void facade.getBaseParityProof(inputs);
await jest.advanceTimersToNextTimerAsync();

const job = broker.enqueueProvingJob.mock.calls[0][0];
const result = makePublicInputsAndRecursiveProof(
makeParityPublicInputs(),
makeRecursiveProof(RECURSIVE_PROOF_LENGTH),
VerificationKeyData.makeFakeHonk(),
);
const outputUri = await proofStore.saveProofOutput(job.id, ProvingRequestType.BASE_PARITY, result);
resolve({
status: 'fulfilled',
value: outputUri,
});

await jest.advanceTimersToNextTimerAsync();
expect(broker.cleanUpProvingJobState).toHaveBeenCalled();
});

it('clears broker state after a job is canceled', async () => {
const { promise, resolve } = promiseWithResolvers<any>();
const catchSpy = jest.fn();
broker.enqueueProvingJob.mockResolvedValue(Promise.resolve());
broker.waitForJobToSettle.mockResolvedValue(promise);

const inputs = makeBaseParityInputs();
const controller = new AbortController();
void facade.getBaseParityProof(inputs, controller.signal).catch(catchSpy);
await jest.advanceTimersToNextTimerAsync();

expect(broker.cancelProvingJob).not.toHaveBeenCalled();
controller.abort();
await jest.advanceTimersToNextTimerAsync();
expect(broker.cancelProvingJob).toHaveBeenCalled();

resolve({
status: 'rejected',
reason: 'Aborted',
});

await jest.advanceTimersToNextTimerAsync();
expect(broker.cleanUpProvingJobState).toHaveBeenCalled();
expect(catchSpy).toHaveBeenCalledWith(new Error('Aborted'));
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ export class CachingBrokerFacade implements ServerCircuitProver {
id: ProvingJobId,
type: T,
inputs: ProvingJobInputsMap[T],
epochNumber = 0,
signal?: AbortSignal,
): Promise<ProvingJobResultsMap[T]> {
// first try the cache
Expand Down Expand Up @@ -95,6 +96,7 @@ export class CachingBrokerFacade implements ServerCircuitProver {
id,
type,
inputsUri,
epochNumber,
});
await this.cache.setProvingJobStatus(id, { status: 'in-queue' });
} catch (err) {
Expand All @@ -107,7 +109,7 @@ export class CachingBrokerFacade implements ServerCircuitProver {
// notify broker of cancelled job
const abortFn = async () => {
signal?.removeEventListener('abort', abortFn);
await this.broker.removeAndCancelProvingJob(id);
await this.broker.cancelProvingJob(id);
};

signal?.addEventListener('abort', abortFn);
Expand Down Expand Up @@ -147,160 +149,174 @@ export class CachingBrokerFacade implements ServerCircuitProver {
}
} finally {
signal?.removeEventListener('abort', abortFn);
// we've saved the result in our cache. We can tell the broker to clear its state
await this.broker.cleanUpProvingJobState(id);
}
}

getAvmProof(
inputs: AvmCircuitInputs,
signal?: AbortSignal,
_blockNumber?: number,
epochNumber?: number,
): Promise<ProofAndVerificationKey<typeof AVM_PROOF_LENGTH_IN_FIELDS>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.PUBLIC_VM, inputs),
ProvingRequestType.PUBLIC_VM,
inputs,
epochNumber,
signal,
);
}

getBaseParityProof(
inputs: BaseParityInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.BASE_PARITY, inputs),
ProvingRequestType.BASE_PARITY,
inputs,
epochNumber,
signal,
);
}

getBlockMergeRollupProof(
input: BlockMergeRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BlockRootOrBlockMergePublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.BLOCK_MERGE_ROLLUP, input),
ProvingRequestType.BLOCK_MERGE_ROLLUP,
input,
epochNumber,
signal,
);
}

getBlockRootRollupProof(
input: BlockRootRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BlockRootOrBlockMergePublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.BLOCK_ROOT_ROLLUP, input),
ProvingRequestType.BLOCK_ROOT_ROLLUP,
input,
epochNumber,
signal,
);
}

getEmptyBlockRootRollupProof(
input: EmptyBlockRootRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BlockRootOrBlockMergePublicInputs>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.EMPTY_BLOCK_ROOT_ROLLUP, input),
ProvingRequestType.EMPTY_BLOCK_ROOT_ROLLUP,
input,
epochNumber,
signal,
);
}

getEmptyPrivateKernelProof(
inputs: PrivateKernelEmptyInputData,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<KernelCircuitPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.PRIVATE_KERNEL_EMPTY, inputs),
ProvingRequestType.PRIVATE_KERNEL_EMPTY,
inputs,
epochNumber,
signal,
);
}

getMergeRollupProof(
input: MergeRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BaseOrMergeRollupPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.MERGE_ROLLUP, input),
ProvingRequestType.MERGE_ROLLUP,
input,
epochNumber,
signal,
);
}
getPrivateBaseRollupProof(
baseRollupInput: PrivateBaseRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BaseOrMergeRollupPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.PRIVATE_BASE_ROLLUP, baseRollupInput),
ProvingRequestType.PRIVATE_BASE_ROLLUP,
baseRollupInput,
epochNumber,
signal,
);
}

getPublicBaseRollupProof(
inputs: PublicBaseRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BaseOrMergeRollupPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.PUBLIC_BASE_ROLLUP, inputs),
ProvingRequestType.PUBLIC_BASE_ROLLUP,
inputs,
epochNumber,
signal,
);
}

getRootParityProof(
inputs: RootParityInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof NESTED_RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.ROOT_PARITY, inputs),
ProvingRequestType.ROOT_PARITY,
inputs,
epochNumber,
signal,
);
}

getRootRollupProof(
input: RootRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<RootRollupPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.ROOT_ROLLUP, input),
ProvingRequestType.ROOT_ROLLUP,
input,
epochNumber,
signal,
);
}

getTubeProof(
tubeInput: TubeInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<ProofAndVerificationKey<typeof TUBE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.TUBE_PROOF, tubeInput),
ProvingRequestType.TUBE_PROOF,
tubeInput,
epochNumber,
signal,
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ describe('ProvingAgent', () => {
const inputs: ProvingJobInputs = { type: ProvingRequestType.BASE_PARITY, inputs: makeBaseParityInputs() };
const job: ProvingJob = {
id: randomBytes(8).toString('hex') as ProvingJobId,
blockNumber: 1,
epochNumber: 1,
type: ProvingRequestType.BASE_PARITY,
inputsUri: randomBytes(8).toString('hex') as ProofUri,
};
Expand Down
Loading

0 comments on commit 32ace1c

Please sign in to comment.