Skip to content

Commit

Permalink
upgrade to latest webgpu spec
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent 4ed1bfb commit 59b10fb
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 34 deletions.
6 changes: 3 additions & 3 deletions js/web/lib/onnxjs/backends/webgpu/ops/binary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,16 @@ const createBinaryOpProgramShader =
}

return `
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
const WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
@group(0) @binding(0) var<storage, read> aData : array<vec4<${typeA}>>;
@group(0) @binding(1) var<storage, read> bData : array<vec4<${typeB}>>;
@group(0) @binding(2) var<storage, write> outputData : array<vec4<${typeOutput}>>;
@group(0) @binding(2) var<storage, read_write> outputData : array<vec4<${typeOutput}>>;
${additionalImplementation ?? ''}
${broadcastImpl}
@stage(compute) @workgroup_size(WORKGROUP_SIZE)
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
// Guard against out-of-bounds work group sizes
Expand Down
6 changes: 3 additions & 3 deletions js/web/lib/onnxjs/backends/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ const createConcatProgramInfo =

const indicesAxis = rank < 2 ? 'indices' : `indices[${axis}]`;
const shaderSource = `
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
const WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
${inputStorageBuffersDeclarations.join('\n')}
@group(0) @binding(${inputs.length}) var<storage, write> output : array<${dataType}>;
@group(0) @binding(${inputs.length}) var<storage, read_write> output : array<${dataType}>;
${inputIndicesHelpers.map(i => i.i2oImpl).join('\n')}
${outputIndicesHelper.o2iImpl}
Expand All @@ -84,7 +84,7 @@ const createConcatProgramInfo =
${calculateInputIndexImpl(sizeInConcatAxis.length)}
${readBufferDataImpl(inputIndicesHelpers, rank, dataType)}
@stage(compute) @workgroup_size(WORKGROUP_SIZE)
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
// Guard against out-of-bounds work group sizes
Expand Down
10 changes: 5 additions & 5 deletions js/web/lib/onnxjs/backends/webgpu/ops/conv-grouped.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,19 @@ const createGroupedConvProgramInfo =
const wIndicesHelper = createIndicesHelper('w', wShape);

const shaderSource = `
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
let strides: vec2<u32> = vec2(${attributes.strides[0]}u, ${attributes.strides[1]}u);
let pads: vec2<u32> = vec2(${attributes.pads[0]}u, ${attributes.pads[1]}u);
const WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
const strides: vec2<u32> = vec2(${attributes.strides[0]}u, ${attributes.strides[1]}u);
const pads: vec2<u32> = vec2(${attributes.pads[0]}u, ${attributes.pads[1]}u);
${inputStorageBuffersDeclarations.join('\n')}
@group(0) @binding(${inputStorageBuffersDeclarations.length}) var<storage, write> output : array<${dataType}>;
@group(0) @binding(${inputStorageBuffersDeclarations.length}) var<storage, read_write> output : array<${dataType}>;
${activationFunction}
${outputIndicesHelper.o2iImpl}
${xIndicesHelper.i2oImpl}
${wIndicesHelper.i2oImpl}
@stage(compute) @workgroup_size(WORKGROUP_SIZE)
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
// Guard against out-of-bounds work group sizes
if (global_id.x >= ${outputSize}u) {
Expand Down
6 changes: 3 additions & 3 deletions js/web/lib/onnxjs/backends/webgpu/ops/gather.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,17 @@ const createGatherProgramInfo =
const indicesIndicesHelper = createIndicesHelper('indices', indicesShape);

const shaderSource = `
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
const WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
@group(0) @binding(0) var<storage, read> data : array<${dataType}>;
@group(0) @binding(1) var<storage, read> indices : array<i32>;
@group(0) @binding(2) var<storage, write> output : array<${dataType}>;
@group(0) @binding(2) var<storage, read_write> output : array<${dataType}>;
${outputIndicesHelper.o2iImpl}
${indicesIndicesHelper.i2oImpl}
${dataIndicesHelper.i2oImpl}
@stage(compute) @workgroup_size(WORKGROUP_SIZE)
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
// Guard against out-of-bounds work group sizes
Expand Down
16 changes: 8 additions & 8 deletions js/web/lib/onnxjs/backends/webgpu/ops/gemm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,17 @@ const createGemmProgramInfo =
inputStorageBuffersDeclarations.push(`@group(0) @binding(2) var<storage, read> c : array<${dataType}>;`);
}
const shaderSource = `
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
let M: u32 = ${M}u;
let N: u32 = ${N}u;
let K: u32 = ${K}u;
let alpha = ${dataType}(${attributes.alpha});
let beta = ${dataType}(${attributes.beta});
const WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
const M: u32 = ${M}u;
const N: u32 = ${N}u;
const K: u32 = ${K}u;
const alpha = ${dataType}(${attributes.alpha});
const beta = ${dataType}(${attributes.beta});
${inputStorageBuffersDeclarations.join('\n')}
@group(0) @binding(${inputs.length}) var<storage, write> output : array<${dataType}>;
@group(0) @binding(${inputs.length}) var<storage, read_write> output : array<${dataType}>;
@stage(compute) @workgroup_size(WORKGROUP_SIZE)
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
// Guard against out-of-bounds work group sizes
Expand Down
12 changes: 6 additions & 6 deletions js/web/lib/onnxjs/backends/webgpu/ops/matmul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,18 @@ function createMatmulProgramInfo(
const K = aShape[aShape.length - 1];
const N = outputShape[outputShape.length - 1];
const shaderSource = `
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
let M: u32 = ${M}u;
let N: u32 = ${N}u;
let K: u32 = ${K}u;
const WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
const M: u32 = ${M}u;
const N: u32 = ${N}u;
const K: u32 = ${K}u;
@group(0) @binding(0) var<storage, read> a : array<${dataType}>;
@group(0) @binding(1) var<storage, read> b : array<${dataType}>;
@group(0) @binding(2) var<storage, write> output : array<${dataType}>;
@group(0) @binding(2) var<storage, read_write> output : array<${dataType}>;
${activationFunction}
@stage(compute) @workgroup_size(WORKGROUP_SIZE)
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
// Guard against out-of-bounds work group sizes
Expand Down
6 changes: 3 additions & 3 deletions js/web/lib/onnxjs/backends/webgpu/ops/slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ const createSliceProgramInfo = (input: Tensor, attributes: SliceAttributes, data
const outputSize = ShapeUtil.size(outputShape);
const outputStrides = ShapeUtil.computeStrides(outputShape);
const shaderSource = `
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
const WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
@group(0) @binding(0) var<storage, read> input : array<${dataType}>;
@group(0) @binding(1) var<storage, write> output : array<${dataType}>;
@group(0) @binding(1) var<storage, read_write> output : array<${dataType}>;
@stage(compute) @workgroup_size(WORKGROUP_SIZE)
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
// Guard against out-of-bounds work group sizes
Expand Down
6 changes: 3 additions & 3 deletions js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ const createElementwiseProgramShader =
expression = funcCall('a');
}
return `
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
const WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
@group(0) @binding(0) var<storage, read> inputData : array<vec4<f32>>;
@group(0) @binding(1) var<storage, write> outputData : array<vec4<f32>>;
@group(0) @binding(1) var<storage, read_write> outputData : array<vec4<f32>>;
${additionalImplementation ?? ''}
@stage(compute) @workgroup_size(WORKGROUP_SIZE)
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
// Guard against out-of-bounds work group sizes
Expand Down

0 comments on commit 59b10fb

Please sign in to comment.