Skip to content

Commit

Permalink
[js/web] Fix NAN caused by buffer reuse in instance-norm
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Feb 29, 2024
1 parent 7455dd1 commit 3666cb4
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 0 deletions.
4 changes: 4 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ const computeMean =
];

const getMeanShaderSource = (shaderHelper: ShaderHelper) => {
const outputInitial = outputType === 'vec2f' ?
`${outputType}(0.0);` :
`${outputType}(vec${components}<f32>(0.0),vec${components}<f32>(0.0));`;
const inputHelper = inputVariable('input', input.dataType, input.dims, components);
return `
${shaderHelper.declareVariables(inputHelper)}
Expand All @@ -151,6 +154,7 @@ const computeMean =
let wgId = global_idx % ${WG};
let wgOffset = wgId * uniforms.wg_size;
if (wgOffset >= uniforms.H) {
output[global_idx] = ${outputInitial};
return;
}
let wgMax = min(wgOffset + uniforms.wg_size, uniforms.H);
Expand Down
82 changes: 82 additions & 0 deletions js/web/test/data/ops/instance-norm-reusebuffer.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
[
{
"name": "Simple test with NHWC, components 1",
"operator": "InstanceNormalization",
"inputShapeDefinitions": "rankOnly",
"opset": {
"domain": "",
"version": 17
},
"cases": [
{
"name": "Simple test",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [2, 3, 1, 1],
"type": "float32"
},
{
"data": [1, 2, 3],
"dims": [3],
"type": "float32"
},
{
"data": [4, 5, 6],
"dims": [3],
"type": "float32"
}
],
"outputs": [
{
"data": [4, 5, 6, 4, 5, 6],
"dims": [2, 3, 1, 1],
"type": "float32"
}
]
}
]
},
{
"name": "Simple test with NHWC, components 2",
"operator": "InstanceNormalization",
"inputShapeDefinitions": "rankOnly",
"opset": {
"domain": "",
"version": 17
},
"cases": [
{
"name": "Simple test",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4, 3, 2],
"dims": [1, 6, 1, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [6],
"type": "float32"
},
{
"data": [4, 5, 6, 7, 8, 9],
"dims": [6],
"type": "float32"
}
],
"outputs": [
{
"data": [
2.775264263153076, 4, 5.224735260009766, 2.5505285263061523, 5, 7.449470520019531, 2.325794219970703, 6,
9.674205780029297, 11.898944854736328, 7, 2.1010589599609375, 14.123676300048828, 8, 1.876321792602539,
16.348413467407227, 9, 1.6515865325927734
],
"dims": [1, 6, 1, 3],
"type": "float32"
}
]
}
]
}
]
80 changes: 80 additions & 0 deletions js/web/test/data/ops/instance-norm.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -224,5 +224,85 @@
]
}
]
},
{
"name": "Simple test with NHWC, components 1, buffer reuse",
"operator": "InstanceNormalization",
"inputShapeDefinitions": "rankOnly",
"opset": {
"domain": "",
"version": 17
},
"cases": [
{
"name": "Simple test",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [2, 3, 1, 1],
"type": "float32"
},
{
"data": [1, 2, 3],
"dims": [3],
"type": "float32"
},
{
"data": [4, 5, 6],
"dims": [3],
"type": "float32"
}
],
"outputs": [
{
"data": [4, 5, 6, 4, 5, 6],
"dims": [2, 3, 1, 1],
"type": "float32"
}
]
}
]
},
{
"name": "Simple test with NHWC, components 2, buffer reuse",
"operator": "InstanceNormalization",
"inputShapeDefinitions": "rankOnly",
"opset": {
"domain": "",
"version": 17
},
"cases": [
{
"name": "Simple test",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4, 3, 2],
"dims": [1, 6, 1, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [6],
"type": "float32"
},
{
"data": [4, 5, 6, 7, 8, 9],
"dims": [6],
"type": "float32"
}
],
"outputs": [
{
"data": [
2.775264263153076, 4, 5.224735260009766, 2.5505285263061523, 5, 7.449470520019531, 2.325794219970703, 6,
9.674205780029297, 11.898944854736328, 7, 2.1010589599609375, 14.123676300048828, 8, 1.876321792602539,
16.348413467407227, 9, 1.6515865325927734
],
"dims": [1, 6, 1, 3],
"type": "float32"
}
]
}
]
}
]

0 comments on commit 3666cb4

Please sign in to comment.