Skip to content

Commit

Permalink
[js/webgpu] Fix NAN caused by un-initialized buffer in instance-norm (#…
Browse files Browse the repository at this point in the history
…19387)

The added case will be NAN because of the un-initialized buffer.
  • Loading branch information
axinging authored Mar 19, 2024
1 parent 6bb6468 commit 4c6a6a3
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 1 deletion.
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ const computeMean =
let offset = currentImageNumber * uniforms.image_size;
var sum = ${fillVector('f32', components)};
var squaredSum = ${fillVector('f32', components)};
for (var i: u32 = 0; i < ${WG}; i++) {
for (var i: u32 = 0; i < min(${WG}, uniforms.H); i++) {
let value = input[offset + i + currentChannelNumber * ${WG}];
sum += value[0];
squaredSum += value[1];
Expand Down
80 changes: 80 additions & 0 deletions js/web/test/data/ops/instance-norm.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -224,5 +224,85 @@
]
}
]
},
{
"name": "Simple test with NHWC, components 1, buffer reuse",
"operator": "InstanceNormalization",
"inputShapeDefinitions": "rankOnly",
"opset": {
"domain": "",
"version": 17
},
"cases": [
{
"name": "Simple test",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [2, 3, 1, 1],
"type": "float32"
},
{
"data": [1, 2, 3],
"dims": [3],
"type": "float32"
},
{
"data": [4, 5, 6],
"dims": [3],
"type": "float32"
}
],
"outputs": [
{
"data": [4, 5, 6, 4, 5, 6],
"dims": [2, 3, 1, 1],
"type": "float32"
}
]
}
]
},
{
"name": "Simple test with NHWC, components 2, buffer reuse",
"operator": "InstanceNormalization",
"inputShapeDefinitions": "rankOnly",
"opset": {
"domain": "",
"version": 17
},
"cases": [
{
"name": "Simple test",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4, 3, 2],
"dims": [1, 6, 1, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [6],
"type": "float32"
},
{
"data": [4, 5, 6, 7, 8, 9],
"dims": [6],
"type": "float32"
}
],
"outputs": [
{
"data": [
2.775264263153076, 4, 5.224735260009766, 2.5505285263061523, 5, 7.449470520019531, 2.325794219970703, 6,
9.674205780029297, 11.898944854736328, 7, 2.1010589599609375, 14.123676300048828, 8, 1.876321792602539,
16.348413467407227, 9, 1.6515865325927734
],
"dims": [1, 6, 1, 3],
"type": "float32"
}
]
}
]
}
]

0 comments on commit 4c6a6a3

Please sign in to comment.