diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/slice.ts b/js/web/lib/onnxjs/backends/webgpu/ops/slice.ts index fda8659993188..c5642c0921811 100644 --- a/js/web/lib/onnxjs/backends/webgpu/ops/slice.ts +++ b/js/web/lib/onnxjs/backends/webgpu/ops/slice.ts @@ -41,6 +41,29 @@ export const parseSliceAttributes: OperatorInitialization<SliceAttributes> = (no return createAttributeWithCacheKey({starts, ends, axes}); }; +const offsetToIndices = (offset: string, strides: readonly number[], indicesPrefix: string): string => { + const outputLines: string[] = []; + + for (let i = 0; i < strides.length - 1; i++) { + outputLines.push(`var ${indicesPrefix}${i}=${offset}/${strides[i]}u;`); + outputLines.push(`${offset}%=${strides[i]}u;`); + } + outputLines.push(`var ${indicesPrefix}${strides.length - 1}=${offset};`); + + return outputLines.join('\n'); +}; + +const indicesToOffset = (indicesPrefix: string, strides: readonly number[], offset: string): string => { + const outputLines: string[] = []; + + for (let i = 0; i < strides.length - 1; i++) { + outputLines.push(`${offset}+=${indicesPrefix}${i} * ${strides[i]}u;`); + } + outputLines.push(`${offset}+=${indicesPrefix}${strides.length - 1};`); + + return outputLines.join('\n'); +}; + const createSliceProgramInfo = (input: Tensor, attributes: SliceAttributes, dataType = 'f32'): ProgramInfo => { const axes = (attributes.axes.length === 0) ? input.dims.slice(0).map((val, i) => i) : attributes.axes; const normalizedAxes = ShapeUtil.normalizeAxes(axes, input.dims.length); @@ -59,12 +82,11 @@ const createSliceProgramInfo = (input: Tensor, attributes: SliceAttributes, data const outputShape = input.dims.slice(); - const sliceOps: Array<[number, number]> = []; + const sliceOps: string[] = []; for (let i = 0; i < normalizedAxes.length; i++) { outputShape[normalizedAxes[i]] = ends[i] - starts[i]; if (starts[i] > 0) { - // sliceOps.push(`outputIdx[${normalizedAxes[i]}] += ${starts[i]};`); - sliceOps.push([normalizedAxes[i], starts[i]]); + sliceOps.push(`idx_${normalizedAxes[i]} += ${starts[i]}u;`); } // else { sliceOps.push(`outputIdx[${normalizedAxes[i]}] += 0;`); } } @@ -84,8 +106,11 @@ const createSliceProgramInfo = (input: Tensor, attributes: SliceAttributes, data } var offset = global_id.x; - ${sliceOps.map(i => `offset += ${i[1]}u * ${outputStrides[i[0]]}u;`).join('')} - output[global_id.x] = input[offset]; + ${offsetToIndices('offset', outputStrides, 'idx_')} + ${sliceOps.join('')} + var offsetInput = 0u; + ${indicesToOffset('idx_', ShapeUtil.computeStrides(input.dims), 'offsetInput')} + output[global_id.x] = input[offsetInput]; }`; return { ...sliceProgramMetadata,