[webgpu] Enable parallel compilation (#7191)
FEATURE

* [webgpu] Add parallel compile

Demo code:
  // Parallel compile.
  tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', true);
  const result1 = predict(model);
  await (tf.backend() as WebGPUBackend).checkCompileCompletionAsync();
  tf.dispose(result1);
  // Actual inference.
  tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', false);
  const result2 = predict(model);
  await result2.data();
  tf.dispose(result2);

* Rename flag to WEBGPU_ENGINE_COMPILE_ONLY
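
For a runnable version of this two-phase pattern, modeled directly on the
new tests below (tf.add stands in for a real model's predict call, and the
import paths are assumed rather than taken from this commit):

  import * as tf from '@tensorflow/tfjs-core';
  import {WebGPUBackend} from '@tensorflow/tfjs-backend-webgpu';

  async function warmUpThenRun() {
    await tf.setBackend('webgpu');
    await tf.ready();

    const a = tf.tensor1d([1, 1, 1]);
    const b = tf.tensor1d([1, 1, 1]);

    // Phase 1: compile-only. Shaders compile in parallel; outputs are
    // invalid and are disposed without being read.
    tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', true);
    const warmup = tf.add(a, b);
    await (tf.backend() as WebGPUBackend).checkCompileCompletionAsync();
    tf.dispose(warmup);

    // Phase 2: real inference, reusing the now-resolved pipeline cache.
    tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', false);
    const result = tf.add(a, b);
    console.log(await result.data());  // [2, 2, 2]
    tf.dispose([a, b, result]);
  }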
axinging authored May 5, 2023
1 parent bea6729 commit c9b746f
Showing 4 changed files with 244 additions and 63 deletions.
129 changes: 78 additions & 51 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
@@ -129,7 +129,8 @@ export class WebGPUBackend extends KernelBackend {
   private dummyContext: GPUCanvasContext;
   private tensorDataPendingDisposal: DataId[] = [];
   private static nextDataId = 0;
-  private pipelineCache: {[key: string]: GPUComputePipeline};
+  private pipelineCache:
+      {[key: string]: GPUComputePipeline|Promise<GPUComputePipeline>};
   private programTimersStack: TimerNode[];
   private querySet: GPUQuerySet;
   private stagingPendingDisposal: BufferInfo[] = [];
@@ -356,8 +357,27 @@
     return this.currentComputePass;
   }
 
+  // Check if parallel compilation is done.
+  async checkCompileCompletionAsync() {
+    let pipelines: GPUComputePipeline[];
+    try {
+      pipelines = await Promise.all(Object.values(this.pipelineCache));
+    } catch (e) {
+      // TODO: Add test case to catch this exception.
+      throw new Error(e.message);
+    }
+    Object.keys(this.pipelineCache).map((key, i) => {
+      this.pipelineCache[key] = pipelines[i];
+    });
+  }
+
   public async getBufferData(buffer: GPUBuffer, size: number):
       Promise<ArrayBuffer> {
+    if (env().getBool('WEBGPU_ENGINE_COMPILE_ONLY')) {
+      console.warn(
+          'The data may be invalid since WEBGPU_ENGINE_COMPILE_ONLY is true; this can only be called when WEBGPU_ENGINE_COMPILE_ONLY is false.');
+      return null;
+    }
     const staging = this.bufferManager.acquireBuffer(
         size, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ);
     this.ensureCommandEncoderReady();
@@ -888,6 +908,47 @@ export class WebGPUBackend extends KernelBackend {
     this.uploadToGPU(output.dataId);
     program.dispatch = reshapeDispatch(this.device, program);
 
+    const inputsData = inputs.map((input: TensorInfo, i: number) => {
+      if (input.dtype === 'complex64') {
+        throw new Error(
+            `GPGPUProgram does not support complex64 input. For complex64 ` +
+            `dtypes, please separate the program into real and imaginary ` +
+            `parts.`);
+      }
+      this.uploadToGPU(input.dataId);
+
+      return {
+        // Returning dtype from tensorMap because it reflects dtype
+        // of underlying buffer, rather than abstract dtype.
+        dtype: this.tensorMap.get(input.dataId).dtype,
+        shape: input.shape,
+        name: program.variableNames[i]
+      };
+    });
+
+    program.shaderKey =
+        webgpu_program.makeShaderKey(program, inputsData, output);
+
+    const parallelCompilation = env().getBool('WEBGPU_ENGINE_COMPILE_ONLY');
+    if (!(program.shaderKey in this.pipelineCache)) {
+      this.pipelineCache[program.shaderKey] = webgpu_program.compileProgram(
+          this.device, program, inputsData, output, parallelCompilation);
+    }
+    program.pipeline = this.pipelineCache[program.shaderKey];
+
+    if (!parallelCompilation) {
+      this.recordAndSubmit(program, output, inputs, programDefinedUniform);
+    }
+    return output;
+  }
+
+  private recordAndSubmit(
+      program: webgpu_program.WebGPUProgram, output: TensorInfo,
+      inputs: TensorInfo[], programDefinedUniform?: ProgramUniform) {
+    if (program.pipeline instanceof Promise<GPUComputePipeline>) {
+      throw new Error(
+          'Please call checkCompileCompletionAsync to ensure parallel compilation is done!');
+    }
     // There are six kinds of uniforms: NAN, INFINITY, shapes, shape strides,
     // program size, program defined uniforms.
     let programUniform: ProgramUniform = [];
@@ -912,36 +973,6 @@
       }
     }
 
-    const inputsData = inputs.map((input: TensorInfo, i: number) => {
-      if (input.dtype === 'complex64') {
-        throw new Error(
-            `GPGPUProgram does not support complex64 input. For complex64 ` +
-            `dtypes, please separate the program into real and imaginary ` +
-            `parts.`);
-      }
-      this.uploadToGPU(input.dataId);
-
-      return {
-        // Returning dtype from tensorMap because it reflects dtype
-        // of underlying buffer, rather than abstract dtype.
-        dtype: this.tensorMap.get(input.dataId).dtype,
-        shape: input.shape,
-        name: program.variableNames[i]
-      };
-    });
-
-    const shaderKey =
-        webgpu_program.makeShaderKey(program, bufferShapes, inputsData, output);
-
-    let pipeline;
-    if (shaderKey in this.pipelineCache) {
-      pipeline = this.pipelineCache[shaderKey];
-    } else {
-      pipeline = webgpu_program.compileProgram(
-          this.device, program, inputsData, output, shaderKey);
-      this.pipelineCache[shaderKey] = pipeline;
-    }
-
     if (programDefinedUniform) {
       programUniform = [...programUniform, ...programDefinedUniform];
     }
@@ -950,49 +981,45 @@
       this.makeUniforms(programUniform)
     ];
 
-    inputs.forEach(input => {
-      this.commandQueueOwnedIds.add(input.dataId);
-    });
-    this.commandQueueOwnedIds.add(output.dataId);
-
     const bindGroup = this.device.createBindGroup({
-      layout: pipeline.getBindGroupLayout(0),
+      layout: program.pipeline.getBindGroupLayout(0),
       entries: bindings.map((b, i) => ({binding: i, resource: b})),
     });
 
     this.ensureCommandEncoderReady();
     const pass = this.getComputePass();
 
     const shouldTimeProgram = this.activeTimers != null;
-    if (shouldTimeProgram) {
-      if (this.supportTimeQuery) {
-        // tslint:disable-next-line:no-any
-        (pass as any).writeTimestamp(this.querySet, 0);
-      }
+    if (shouldTimeProgram && this.supportTimeQuery) {
+      // tslint:disable-next-line:no-any
+      (pass as any).writeTimestamp(this.querySet, 0);
     }
-    pass.setPipeline(pipeline);
+
+    pass.setPipeline(program.pipeline);
     pass.setBindGroup(0, bindGroup);
     pass.dispatchWorkgroups(
         program.dispatch[0], program.dispatch[1], program.dispatch[2]);
-    if (shouldTimeProgram) {
-      if (this.supportTimeQuery) {
-        // tslint:disable-next-line:no-any
-        (pass as any).writeTimestamp(this.querySet, 1);
-      }
+
+    if (shouldTimeProgram && this.supportTimeQuery) {
+      // tslint:disable-next-line:no-any
+      (pass as any).writeTimestamp(this.querySet, 1);
     }
     this.dispatchNumberInEncoder++;
 
+    inputs.forEach(input => {
+      this.commandQueueOwnedIds.add(input.dataId);
+    });
+    this.commandQueueOwnedIds.add(output.dataId);
+
     if (env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE') as
            number <= this.dispatchNumberInEncoder) {
       this.submitQueue();
     }
 
     if (shouldTimeProgram) {
       this.activeTimers.push({
         name: program.constructor.name,
         query: this.getQueryTime(this.querySet)
       });
     }
-    return output;
   }
 
   async getTimeFromQuerySet(querySet: GPUQuerySet) {
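
The heart of the change above: pipelineCache may now hold pending promises,
and checkCompileCompletionAsync resolves them all in one pass and writes the
compiled pipelines back. A minimal standalone sketch of the same pattern,
built on the standard GPUDevice.createComputePipelineAsync API (the cache,
getOrCompile, and awaitAllPipelines names are illustrative, not part of the
commit):

  const cache:
      {[key: string]: GPUComputePipeline|Promise<GPUComputePipeline>} = {};

  function getOrCompile(
      device: GPUDevice, key: string, desc: GPUComputePipelineDescriptor) {
    if (!(key in cache)) {
      // createComputePipelineAsync returns a promise, so the browser may
      // compile several shaders concurrently without blocking this thread.
      cache[key] = device.createComputePipelineAsync(desc);
    }
    return cache[key];
  }

  // Before recording any dispatch, resolve everything in one pass, as
  // checkCompileCompletionAsync does above.
  async function awaitAllPipelines() {
    const pipelines = await Promise.all(Object.values(cache));
    Object.keys(cache).forEach((key, i) => cache[key] = pipelines[i]);
  }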
133 changes: 133 additions & 0 deletions tfjs-backend-webgpu/src/backend_webgpu_test.ts
@@ -353,6 +353,139 @@ describeWebGPU('keeping data on gpu ', () => {
   });
 });
 
+async function parallelCompilationCommon(webGPUBackend: WebGPUBackend) {
+  const startNumBytes = (tf.memory() as WebGPUMemoryInfo).numBytesInGPU;
+  const startTensor = tf.memory().numTensors;
+  const startDataBuckets = webGPUBackend.numDataIds();
+
+  const a1 = tf.tensor1d([1, 1, 1]);
+  const b1 = tf.tensor1d([1, 1, 1]);
+
+  // Parallel compile.
+  tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', true);
+  const c1 = tf.add(a1, b1);
+  await webGPUBackend.checkCompileCompletionAsync();
+
+  // Actual inference.
+  tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', false);
+  const c2 = tf.add(a1, b1);
+  expectArraysEqual(await c2.data(), [2, 2, 2]);
+
+  tf.dispose([a1, b1, c1, c2]);
+  const endNumBytes = (tf.memory() as WebGPUMemoryInfo).numBytesInGPU;
+  const endTensor = tf.memory().numTensors;
+  const endDataBuckets = webGPUBackend.numDataIds();
+
+  // We only check numBytesInGPU. With parallel compilation,
+  // numBytesInGPUAllocated will be larger because of the two-pass
+  // uploadToGPU, but those buffers are all freed, so endNumBytes equals
+  // startNumBytes.
+  expect(startNumBytes).toEqual(endNumBytes);
+  expect(startTensor).toEqual(endTensor);
+  expect(endDataBuckets).toEqual(startDataBuckets);
+}
+
+describeWebGPU('parallel compilation', () => {
+  let prevBackend: string;
+  let savedWebGPUCPUForward: boolean;
+  let savedEngineCompileOnly: boolean;
+  let webGPUBackend: WebGPUBackend;
+  const customWebGPUBackendName = 'test-parallel';
+
+  beforeAll(() => {
+    prevBackend = tf.getBackend();
+  });
+
+  beforeEach(async () => {
+    const adapter = await navigator.gpu.requestAdapter({});
+    const device = await adapter.requestDevice({});
+    webGPUBackend = new WebGPUBackend(device);
+
+    tf.copyRegisteredKernels('webgpu', customWebGPUBackendName);
+    tf.registerBackend(customWebGPUBackendName, () => webGPUBackend);
+    tf.setBackend('test-parallel');
+
+    savedWebGPUCPUForward = tf.env().get('WEBGPU_CPU_FORWARD') as boolean;
+    savedEngineCompileOnly =
+        tf.env().get('WEBGPU_ENGINE_COMPILE_ONLY') as boolean;
+    tf.env().set('WEBGPU_CPU_FORWARD', false);
+    await tf.ready();
+  });
+
+  afterEach(() => {
+    tf.env().set('WEBGPU_CPU_FORWARD', savedWebGPUCPUForward);
+    tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', savedEngineCompileOnly);
+    tf.setBackend(prevBackend);
+    tf.removeBackend(customWebGPUBackendName);
+  });
+
+  it('should work if the pipeline cache does not exist', async () => {
+    await parallelCompilationCommon(webGPUBackend);
+  });
+
+  it('should work if the pipeline cache exists', async () => {
+    // This will create the pipeline cache.
+    const a0 = tf.tensor1d([1, 1, 1]);
+    const b0 = tf.tensor1d([1, 1, 1]);
+    const c0 = tf.add(a0, b0);
+    const data = await c0.data();
+    expectArraysClose(data, [2, 2, 2]);
+
+    await parallelCompilationCommon(webGPUBackend);
+  });
+
+  it('should work when running parallel compile again', async () => {
+    // This will create the pipeline cache.
+    const a0 = tf.tensor1d([1, 1, 1]);
+    const b0 = tf.tensor1d([1, 1, 1]);
+    const c0 = tf.add(a0, b0);
+    const data = await c0.data();
+    expectArraysClose(data, [2, 2, 2]);
+
+    await parallelCompilationCommon(webGPUBackend);
+    await parallelCompilationCommon(webGPUBackend);
+  });
+
+  it('should throw if checkCompileCompletionAsync is not called', async () => {
+    const a1 = tf.tensor1d([1, 1, 1]);
+    const b1 = tf.tensor1d([1, 1, 1]);
+
+    // Parallel compile, without awaiting
+    // (tf.backend() as WebGPUBackend).checkCompileCompletionAsync().
+    tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', true);
+    tf.add(a1, b1);
+
+    // Actual inference.
+    tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', false);
+    expect(() => tf.add(a1, b1))
+        .toThrowError(
+            'Please call checkCompileCompletionAsync to ensure parallel compilation is done!');
+  });
+
+  it('read data is invalid if parallel compilation is true', async () => {
+    const a1 = tf.tensor1d([1, 1, 1]);
+    const b1 = tf.tensor1d([1, 1, 1]);
+
+    // Parallel compile.
+    tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', true);
+    const c1 = tf.add(a1, b1);
+    await (tf.backend() as WebGPUBackend).checkCompileCompletionAsync();
+    // Data read back in compile-only mode is invalid.
+    expectArraysClose((await c1.data()).length, 0);
+  });
+
+  it('checkCompileCompletionAsync is a no-op if parallel compilation is false',
+     async () => {
+       const a1 = tf.tensor1d([1, 1, 1]);
+       const b1 = tf.tensor1d([1, 1, 1]);
+       // If parallel compilation is false, checkCompileCompletionAsync is a
+       // no-op.
+       tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', false);
+       const c1 = tf.add(a1, b1);
+       await (tf.backend() as WebGPUBackend).checkCompileCompletionAsync();
+       expectArraysClose(await c1.data(), [2, 2, 2]);
+     });
+});
+
 function createStagingGPUBufferFromData(
     device: GPUDevice, data: number[], dtype: tf.DataType) {
   const bytesPerElement = 4;
3 changes: 3 additions & 0 deletions tfjs-backend-webgpu/src/flags_webgpu.ts
@@ -89,3 +89,6 @@ ENV.registerFlag('WEBGPU_CONV_SEPARATE_IM2COL_SHADER', () => false);
  * etc.). 'unary,conv2d' to print both unary and conv2d.
  */
 ENV.registerFlag('WEBGPU_PRINT_SHADER', () => '');
+
+/** Experimental flag: whether to enter the compile-only phase. */
+ENV.registerFlag('WEBGPU_ENGINE_COMPILE_ONLY', () => false);
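
Like the other flags registered in this file, the new flag is read and
toggled through tf.env(); a minimal sketch:

  import * as tf from '@tensorflow/tfjs-core';

  // Registered flags are exposed via tf.env(); this one defaults to false.
  const compileOnly = tf.env().getBool('WEBGPU_ENGINE_COMPILE_ONLY');

  // Flip it on for the warm-up pass, then off again for real inference.
  tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', true);
  tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', false);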