Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: memory leak in the broker #10567

Merged
merged 4 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions yarn-project/circuit-types/src/interfaces/prover-broker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,16 @@ export interface ProvingJobProducer {
enqueueProvingJob(job: ProvingJob): Promise<void>;

/**
* Cancels a proving job and clears all of its
* Cancels a proving job.
* @param id - The ID of the job to cancel
*/
removeAndCancelProvingJob(id: ProvingJobId): Promise<void>;
cancelProvingJob(id: ProvingJobId): Promise<void>;

/**
* Cleans up after a job has completed. Throws if the job is in-progress
* @param id - The ID of the job to cancel
*/
cleanUpProvingJobState(id: ProvingJobId): Promise<void>;

/**
* Returns the current status fof the proving job
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class MockProvingJobSource implements ProvingJobSource {
id: 'a-job-id',
type: ProvingRequestType.PRIVATE_BASE_ROLLUP,
inputsUri: 'inputs-uri' as ProofUri,
epochNumber: 1,
});
}
heartbeat(jobId: string): Promise<void> {
Expand Down
2 changes: 1 addition & 1 deletion yarn-project/circuit-types/src/interfaces/proving-job.ts
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ export type ProvingJobId = z.infer<typeof ProvingJobId>;
export const ProvingJob = z.object({
id: ProvingJobId,
type: z.nativeEnum(ProvingRequestType),
blockNumber: z.number().optional(),
epochNumber: z.number(),
inputsUri: ProofUri,
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ export class MemoryProvingQueue implements ServerCircuitProver, ProvingJobSource
id: job.id,
type: job.type,
inputsUri: job.inputsUri,
epochNumber: job.epochNumber,
};
} catch (err) {
if (err instanceof TimeoutError) {
Expand Down Expand Up @@ -244,7 +245,7 @@ export class MemoryProvingQueue implements ServerCircuitProver, ProvingJobSource
reject,
attempts: 1,
heartbeat: 0,
epochNumber,
epochNumber: epochNumber ?? 0,
};

if (signal) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ describe('CachingBrokerFacade', () => {
broker = mock<ProvingJobProducer>({
enqueueProvingJob: jest.fn<any>(),
getProvingJobStatus: jest.fn<any>(),
removeAndCancelProvingJob: jest.fn<any>(),
cancelProvingJob: jest.fn<any>(),
cleanUpProvingJobState: jest.fn<any>(),
waitForJobToSettle: jest.fn<any>(),
});
cache = new InMemoryProverCache();
Expand Down Expand Up @@ -101,4 +102,55 @@ describe('CachingBrokerFacade', () => {
await expect(facade.getBaseParityProof(inputs)).resolves.toEqual(result);
expect(broker.enqueueProvingJob).toHaveBeenCalledTimes(1); // job was only ever enqueued once
});

it('clears broker state after a job resolves', async () => {
const { promise, resolve } = promiseWithResolvers<any>();
broker.enqueueProvingJob.mockResolvedValue(Promise.resolve());
broker.waitForJobToSettle.mockResolvedValue(promise);

const inputs = makeBaseParityInputs();
void facade.getBaseParityProof(inputs);
await jest.advanceTimersToNextTimerAsync();

const job = broker.enqueueProvingJob.mock.calls[0][0];
const result = makePublicInputsAndRecursiveProof(
makeParityPublicInputs(),
makeRecursiveProof(RECURSIVE_PROOF_LENGTH),
VerificationKeyData.makeFakeHonk(),
);
const outputUri = await proofStore.saveProofOutput(job.id, ProvingRequestType.BASE_PARITY, result);
resolve({
status: 'fulfilled',
value: outputUri,
});

await jest.advanceTimersToNextTimerAsync();
expect(broker.cleanUpProvingJobState).toHaveBeenCalled();
});

it('clears broker state after a job is canceled', async () => {
const { promise, resolve } = promiseWithResolvers<any>();
const catchSpy = jest.fn();
broker.enqueueProvingJob.mockResolvedValue(Promise.resolve());
broker.waitForJobToSettle.mockResolvedValue(promise);

const inputs = makeBaseParityInputs();
const controller = new AbortController();
void facade.getBaseParityProof(inputs, controller.signal).catch(catchSpy);
await jest.advanceTimersToNextTimerAsync();

expect(broker.cancelProvingJob).not.toHaveBeenCalled();
controller.abort();
await jest.advanceTimersToNextTimerAsync();
expect(broker.cancelProvingJob).toHaveBeenCalled();

resolve({
status: 'rejected',
reason: 'Aborted',
});

await jest.advanceTimersToNextTimerAsync();
expect(broker.cleanUpProvingJobState).toHaveBeenCalled();
expect(catchSpy).toHaveBeenCalledWith(new Error('Aborted'));
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ export class CachingBrokerFacade implements ServerCircuitProver {
id: ProvingJobId,
type: T,
inputs: ProvingJobInputsMap[T],
epochNumber = 0,
signal?: AbortSignal,
): Promise<ProvingJobResultsMap[T]> {
// first try the cache
Expand Down Expand Up @@ -95,6 +96,7 @@ export class CachingBrokerFacade implements ServerCircuitProver {
id,
type,
inputsUri,
epochNumber,
});
await this.cache.setProvingJobStatus(id, { status: 'in-queue' });
} catch (err) {
Expand All @@ -107,7 +109,7 @@ export class CachingBrokerFacade implements ServerCircuitProver {
// notify broker of cancelled job
const abortFn = async () => {
signal?.removeEventListener('abort', abortFn);
await this.broker.removeAndCancelProvingJob(id);
await this.broker.cancelProvingJob(id);
};

signal?.addEventListener('abort', abortFn);
Expand Down Expand Up @@ -147,160 +149,174 @@ export class CachingBrokerFacade implements ServerCircuitProver {
}
} finally {
signal?.removeEventListener('abort', abortFn);
// we've saved the result in our cache. We can tell the broker to clear its state
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remind me: who's expected to be persisting proving data? I wouldn't want to clean up the broker data if we were counting on it to persist work info across crashes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤦 I misplaced my comment sorry about that #10567 (comment)

await this.broker.cleanUpProvingJobState(id);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I read this correctly. Does this mean the orchestrator effectively tells the broker when to clear state. This feels like it's quite coupled. Would a better approach be for the broker to simply e.g. delete all state for epochs < N - 1 when it is asked to prove something for epoch N.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you that's right. I think your suggestion makes a lot of sense!

}
}

getAvmProof(
inputs: AvmCircuitInputs,
signal?: AbortSignal,
_blockNumber?: number,
epochNumber?: number,
): Promise<ProofAndVerificationKey<typeof AVM_PROOF_LENGTH_IN_FIELDS>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.PUBLIC_VM, inputs),
ProvingRequestType.PUBLIC_VM,
inputs,
epochNumber,
signal,
);
}

getBaseParityProof(
inputs: BaseParityInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.BASE_PARITY, inputs),
ProvingRequestType.BASE_PARITY,
inputs,
epochNumber,
signal,
);
}

getBlockMergeRollupProof(
input: BlockMergeRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BlockRootOrBlockMergePublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.BLOCK_MERGE_ROLLUP, input),
ProvingRequestType.BLOCK_MERGE_ROLLUP,
input,
epochNumber,
signal,
);
}

getBlockRootRollupProof(
input: BlockRootRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BlockRootOrBlockMergePublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.BLOCK_ROOT_ROLLUP, input),
ProvingRequestType.BLOCK_ROOT_ROLLUP,
input,
epochNumber,
signal,
);
}

getEmptyBlockRootRollupProof(
input: EmptyBlockRootRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BlockRootOrBlockMergePublicInputs>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.EMPTY_BLOCK_ROOT_ROLLUP, input),
ProvingRequestType.EMPTY_BLOCK_ROOT_ROLLUP,
input,
epochNumber,
signal,
);
}

getEmptyPrivateKernelProof(
inputs: PrivateKernelEmptyInputData,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<KernelCircuitPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.PRIVATE_KERNEL_EMPTY, inputs),
ProvingRequestType.PRIVATE_KERNEL_EMPTY,
inputs,
epochNumber,
signal,
);
}

getMergeRollupProof(
input: MergeRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BaseOrMergeRollupPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.MERGE_ROLLUP, input),
ProvingRequestType.MERGE_ROLLUP,
input,
epochNumber,
signal,
);
}
getPrivateBaseRollupProof(
baseRollupInput: PrivateBaseRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BaseOrMergeRollupPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.PRIVATE_BASE_ROLLUP, baseRollupInput),
ProvingRequestType.PRIVATE_BASE_ROLLUP,
baseRollupInput,
epochNumber,
signal,
);
}

getPublicBaseRollupProof(
inputs: PublicBaseRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BaseOrMergeRollupPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.PUBLIC_BASE_ROLLUP, inputs),
ProvingRequestType.PUBLIC_BASE_ROLLUP,
inputs,
epochNumber,
signal,
);
}

getRootParityProof(
inputs: RootParityInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof NESTED_RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.ROOT_PARITY, inputs),
ProvingRequestType.ROOT_PARITY,
inputs,
epochNumber,
signal,
);
}

getRootRollupProof(
input: RootRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<RootRollupPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.ROOT_ROLLUP, input),
ProvingRequestType.ROOT_ROLLUP,
input,
epochNumber,
signal,
);
}

getTubeProof(
tubeInput: TubeInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<ProofAndVerificationKey<typeof TUBE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.TUBE_PROOF, tubeInput),
ProvingRequestType.TUBE_PROOF,
tubeInput,
epochNumber,
signal,
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ describe('ProvingAgent', () => {
const inputs: ProvingJobInputs = { type: ProvingRequestType.BASE_PARITY, inputs: makeBaseParityInputs() };
const job: ProvingJob = {
id: randomBytes(8).toString('hex') as ProvingJobId,
blockNumber: 1,
epochNumber: 1,
type: ProvingRequestType.BASE_PARITY,
inputsUri: randomBytes(8).toString('hex') as ProofUri,
};
Expand Down
Loading
Loading