diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index 4e2bfa9d89924..3691b5ecb602b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -48,11 +48,18 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const shape = Array.from(inputs[1].getBigInt64Array(), Number); const outputShape: number[] = calculateOutputShape(inputShape, shape); const dataType = inputs[0].dataType; - const components = dataType === DataType.bool ? 4 : 1; + const isBoolOrScalar = dataType === DataType.bool || ShapeUtil.size(inputShape) === 1; + const iComponents = + dataType === DataType.bool ? 4 : inputShape.length > 0 && inputShape[inputShape.length - 1] % 4 === 0 ? 4 : 1; + const components = isBoolOrScalar + ? 4 + : outputShape.length > 0 && outputShape[outputShape.length - 1] % 4 === 0 + ? 4 + : 1; const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); const getShaderSource = (shaderHelper: ShaderHelper) => { - const input = inputVariable('input', dataType, inputShape.length, components); + const input = inputVariable('input', dataType, inputShape.length, iComponents); const output = outputVariable('output', dataType, outputShape.length, components); let assignment: string; if (dataType === DataType.bool) { @@ -74,9 +81,10 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => }`; } else { assignment = ` - let outputIndices = ${output.offsetToIndices('global_idx')}; + let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)}; let inputOffset = ${input.broadcastedIndicesToOffset('outputIndices', output)}; - ${output.setByOffset('global_idx', input.getByOffset('inputOffset'))} + let data = ${output.type.value}(${input.getByOffset(`inputOffset / ${iComponents}`)}); + ${output.setByOffset('global_idx', 'data')} }`; } return ` @@ -92,7 +100,7 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => ]; return { name: 'Expand', - shaderCache: { hint: `${outputShape.length}`, inputDependencies: ['rank'] }, + shaderCache: { hint: `${outputShape.length};${iComponents}${components}`, inputDependencies: ['rank'] }, getShaderSource, getRunData: () => ({ outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], diff --git a/js/web/test/data/ops/expand.jsonc b/js/web/test/data/ops/expand.jsonc index 613b4507b2b15..8fbe9339feb9b 100644 --- a/js/web/test/data/ops/expand.jsonc +++ b/js/web/test/data/ops/expand.jsonc @@ -134,6 +134,56 @@ "type": "float32" } ] + }, + { + "name": "Expand in components = 1, out components = 4", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [3, 2, 1], + "type": "float32" + }, + { + "data": [3, 1, 8], + "dims": [3], + "type": "int64" + } + ], + "outputs": [ + { + "data": [ + 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, + 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6 + ], + "dims": [3, 2, 8], + "type": "float32" + } + ] + }, + { + "name": "Expand in components = 4, out components = 4", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 1, 2, 8], + "type": "float32" + }, + { + "data": [2, 1, 8], + "dims": [3], + "type": "int64" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16 + ], + "dims": [1, 2, 2, 8], + "type": "float32" + } + ] } ] },