-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
first operator (correctness validated)
- Loading branch information
Showing
11 changed files
with
438 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
import {Guid} from 'guid-typescript'; | ||
import {sizeof, Tensor} from '../../tensor'; | ||
import {ShapeUtil} from '../../util'; | ||
import {GpuData, GpuDataId, GpuDataType} from './types'; | ||
|
||
/**
 * manages GpuDataId -> GpuBuffer
 *
 * A GpuDataManager owns the WebGPU buffers that back tensor data, mapping each
 * GpuDataId to its GPU-side storage.
 */
export interface GpuDataManager {
  /** copy (upload) the CPU data of `tensor` into a GPU buffer and return its GpuData handle. */
  uploadData(tensor: Tensor, gpuDataType: GpuDataType): GpuData;
  /**
   * allocate an uninitialized GPU buffer large enough for a tensor of the given element type and
   * dims (typically for kernel outputs), registered under a freshly created id.
   */
  createData(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): GpuData;
  /** destroy the GPU buffer associated with the tensor id. Throws if the id is unknown. */
  releaseData(tensorId: Tensor.Id): void;
  /** copy (download) the GPU buffer content of the tensor id back to the CPU. */
  downloadData(tensorId: Tensor.Id): Promise<ArrayBufferLike>;
}
|
||
// cache entry for a resident GPU buffer
interface DefaultCacheValue {
  gpuData: GpuData;  // handle (id, type, buffer) for the GPU-side storage
  size: number;      // buffer byte length; used as the copy size when downloading
}
|
||
// cache entry for a GPU -> CPU readback
interface DownloadCacheValue {
  gpuData: GpuData;                // the GPU buffer the data was read back from
  data: Promise<ArrayBufferLike>;  // pending (or resolved) CPU-side copy of the buffer content
}
|
||
class GpuDataManagerImpl implements GpuDataManager { | ||
defaultCache: Map<GpuDataId, DefaultCacheValue>; | ||
downloadCache: Map<GpuDataId, DownloadCacheValue>; | ||
constructor(private device: GPUDevice) { | ||
this.defaultCache = new Map(); | ||
this.downloadCache = new Map(); | ||
} | ||
|
||
uploadData(tensor: Tensor, gpuDataType: GpuDataType): GpuData { | ||
if (gpuDataType !== GpuDataType.default) { | ||
throw new Error('we only support default GPU data type now'); | ||
} | ||
|
||
const cachedData = this.defaultCache.get(tensor.dataId); | ||
if (cachedData) { | ||
return cachedData.gpuData; | ||
} | ||
|
||
const src = tensor.numberData; | ||
const srcArrayBuffer = src.buffer; | ||
const srcOffset = src.byteOffset; | ||
const srcLength = src.byteLength; | ||
|
||
// create gpu buffer | ||
const gpuBuffer = | ||
this.device.createBuffer({mappedAtCreation: true, size: srcLength, usage: GPUBufferUsage.STORAGE}); | ||
|
||
// copy (upload) data | ||
const arrayBuffer = gpuBuffer.getMappedRange(); | ||
new Uint8Array(arrayBuffer).set(new Uint8Array(srcArrayBuffer, srcOffset, srcLength)); | ||
gpuBuffer.unmap(); | ||
|
||
const gpuData = {id: tensor.dataId, type: GpuDataType.default, buffer: gpuBuffer}; | ||
this.defaultCache.set(gpuData.id, {gpuData, size: srcLength}); | ||
return gpuData; | ||
} | ||
|
||
createData(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): GpuData { | ||
if (gpuDataType !== GpuDataType.default) { | ||
throw new Error('we only support default GPU data type now'); | ||
} | ||
|
||
// !!! | ||
// !!! IMPORTANT: TODO: whether we should keep the storage buffer every time, or always create new ones. | ||
// !!! This need to be figured out by performance test results. | ||
// !!! | ||
|
||
const elemCount = ShapeUtil.size(dims); | ||
const bufferLength = sizeof(type) * elemCount; | ||
|
||
// create gpu buffer | ||
const gpuBuffer = | ||
// eslint-disable-next-line no-bitwise | ||
this.device.createBuffer({size: bufferLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC}); | ||
|
||
const gpuData = {id: Guid.create(), type: GpuDataType.default, buffer: gpuBuffer}; | ||
this.defaultCache.set(gpuData.id, {gpuData, size: bufferLength}); | ||
return gpuData; | ||
} | ||
|
||
releaseData(tensorId: Tensor.Id): void { | ||
const cachedData = this.defaultCache.get(tensorId); | ||
if (!cachedData) { | ||
throw new Error('releasing data does not exist'); | ||
} | ||
|
||
this.defaultCache.delete(tensorId); | ||
cachedData.gpuData.buffer.destroy(); | ||
} | ||
|
||
async downloadData(tensorId: Tensor.Id): Promise<ArrayBufferLike> { | ||
const downloadData = this.downloadCache.get(tensorId); | ||
if (downloadData) { | ||
return downloadData.data; | ||
} | ||
|
||
const cachedData = this.defaultCache.get(tensorId); | ||
if (!cachedData) { | ||
throw new Error('data does not exist'); | ||
} | ||
|
||
const commandEncoder = this.device.createCommandEncoder(); | ||
const gpuReadBuffer = | ||
// eslint-disable-next-line no-bitwise | ||
this.device.createBuffer({size: cachedData.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ}); | ||
commandEncoder.copyBufferToBuffer( | ||
cachedData.gpuData.buffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */, | ||
0 /* destination offset */, cachedData.size /* size */ | ||
); | ||
const gpuCommands = commandEncoder.finish(); | ||
this.device.queue.submit([gpuCommands]); | ||
|
||
await gpuReadBuffer.mapAsync(GPUMapMode.READ); | ||
return gpuReadBuffer.getMappedRange(); | ||
} | ||
} | ||
|
||
/** creates a GpuDataManager backed by the given WebGPU device (GpuDataManagerImpl is not exported directly). */
export const createGpuDataManager = (device: GPUDevice): GpuDataManager => new GpuDataManagerImpl(device);
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
import {Tensor} from '../../../tensor'; | ||
import {WebGpuInferenceHandler} from '../inference-handler'; | ||
import {GpuDataType} from '../types'; | ||
|
||
export const abs = (handler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => handler.run( | ||
{ | ||
name: 'Abs', | ||
inputTypes: [GpuDataType.default], | ||
// inputLayouts: [], | ||
// outputLayouts: [], | ||
shaderSource: ` | ||
@group(0) @binding(0) var<storage, read> inputData : array<f32>; | ||
@group(0) @binding(1) var<storage, write> outputData : array<f32>; | ||
@stage(compute) @workgroup_size(32) | ||
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) { | ||
// Guard against out-of-bounds work group sizes | ||
if (global_id.x * 32u >= ${inputs[0].size}u) { | ||
return; | ||
} | ||
// | ||
// TODO: SIMD? | ||
// | ||
let start = global_id.x * 32u; | ||
let end = select(start + 32u, ${inputs[0].size}u, start + 32u > ${inputs[0].size}u); | ||
for (var i = start; i < end; i = i + 1u) { | ||
outputData[i] = abs(inputData[i]); | ||
} | ||
}`, | ||
outputs: [{dims: inputs[0].dims, type: inputs[0].type, gpuDataType: GpuDataType.default}], | ||
// entryPoint: 'main', | ||
dispatchGroup: (inputTensors) => ({x: Math.ceil(inputTensors[0].size / 32)}) | ||
}, | ||
inputs); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
import {Profiler} from '../../instrument'; | ||
|
||
import {Artifact, GpuData, ProgramInfo} from './types'; | ||
|
||
/**
 * ProgramManager is the main class behind running computations
 * It builds ProgramInfo's into Artifacts
 * It compiles given ProgramInfo's into WebGPU compute pipelines (cached as
 * Artifacts) and runs an Artifact by encoding a compute pass, binding the
 * input/output GPU buffers, and submitting the pass to the device queue.
 */
export class ProgramManager {
  repo: Map<unknown, Artifact>;  // this should be per-session object
  // NOTE(review): looks like a leftover from the WebGL program manager — it is set to false in the
  // constructor but never read anywhere in this class; confirm before removing.
  attributesBound: boolean;

  constructor(private device: GPUDevice, public profiler: Readonly<Profiler>) {
    this.repo = new Map();
    this.attributesBound = false;
  }
  /** look up a previously built Artifact by its cache key. */
  getArtifact(key: unknown): Artifact|undefined {
    return this.repo.get(key);
  }
  /** store a built Artifact under the given cache key. */
  setArtifact(key: unknown, artifact: Artifact): void {
    this.repo.set(key, artifact);
  }
  /**
   * run a compiled compute pipeline over the given GPU buffers.
   *
   * Bindings are assigned sequentially in bind group 0: all inputs first, then all outputs —
   * this must match the `@group(0) @binding(n)` declarations in the shader source.
   */
  run(buildArtifact: Artifact, inputs: GpuData[], outputs: GpuData[],
      dispatchGroup: {x: number; y?: number; z?: number}): void {
    const device = this.device;

    // TODO: should we create command encoder every time?

    const commandEncoder = device.createCommandEncoder();

    const passEncoder = commandEncoder.beginComputePass();
    passEncoder.setPipeline(buildArtifact.computePipeline);
    // binding index = position in `entries`: inputs occupy the first slots, outputs follow.
    const entries = [];
    for (const input of inputs) {
      entries.push({binding: entries.length, resource: {buffer: input.buffer}});
    }
    for (const output of outputs) {
      entries.push({binding: entries.length, resource: {buffer: output.buffer}});
    }
    // use the layout the pipeline inferred for group 0 rather than an explicit layout object.
    const bindGroup = device.createBindGroup({layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries});
    passEncoder.setBindGroup(0, bindGroup);

    const {x, y, z} = dispatchGroup;
    passEncoder.dispatch(x, y, z);

    passEncoder.endPass();

    device.queue.submit([commandEncoder.finish()]);
  }
  dispose(): void {
    // TODO: nothing is released here yet; the commented line below is the old WebGL cleanup.
    // this.repo.forEach(a => this.glContext.deleteProgram(a.program));
  }
  /** compile the shader source in `programInfo` into a compute pipeline (entry point 'main'), wrapped as an Artifact. */
  build(programInfo: ProgramInfo): Artifact {
    const device = this.device;

    const shaderModule = device.createShaderModule({code: programInfo.shaderSource});

    const computePipeline = device.createComputePipeline({compute: {module: shaderModule, entryPoint: 'main'}});

    return {programInfo, computePipeline};
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.