-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix avgPool3d #7133
Fix avgPool3d #7133
Changes from 14 commits
411bae3
385c1a6
010dfbc
3b62ff4
dbc15e6
8299c05
1822a83
061f929
e4afabf
b6e48a0
67a5f4a
02b63da
f70417c
f9036e3
9fb2bc5
5cba16b
e81b203
77bc49b
0bec23e
f37e707
58108f1
1b478d6
60c9379
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -342,7 +342,10 @@ export class Pool3DProgram implements GPGPUProgram { | |
let returnValue = `${poolType}(${poolType}(${poolType}(` + | ||
'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])'; | ||
if (poolType === 'avg') { | ||
returnValue = `avgValue / count`; | ||
// Use `max(count, 1.0)` instead of `count` in case count === 0.0. | ||
// If count === 0.0, `avgValue` is always 0.0 and we change `count`'s | ||
// value to avoid dividing zero. | ||
returnValue = `avgValue / max(count, 1.0)`; | ||
} | ||
|
||
const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4; | ||
|
@@ -448,8 +451,8 @@ export class Pool3DProgram implements GPGPUProgram { | |
${updateSnippet} | ||
} | ||
} | ||
setOutput(${returnValue}); | ||
} | ||
setOutput(${returnValue}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
} | ||
`; | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -365,24 +365,23 @@ function computeOutputShape2D( | |
} | ||
|
||
function computeOutputShape4D( | ||
inShape: [number, number, number, number], fieldSize: number, | ||
outChannels: number, stride: number, zeroPad?: number, | ||
inShape: [number, number, number, number], | ||
filterShape: [number, number, number], outChannels: number, | ||
strides: [number, number, number], zeroPad?: number, | ||
roundingMode?: 'floor'|'round'|'ceil'): [number, number, number, number] { | ||
if (zeroPad == null) { | ||
zeroPad = computeDefaultPad(inShape, fieldSize, stride); | ||
zeroPad = computeDefaultPad(inShape, filterShape[0], strides[0]); | ||
} | ||
const inputDepth = inShape[0]; | ||
const inputRows = inShape[1]; | ||
const inputCols = inShape[2]; | ||
|
||
const outputDepths = | ||
round((inputDepth - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); | ||
const outputRows = | ||
round((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); | ||
const outputCols = | ||
round((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); | ||
|
||
return [outputDepths, outputRows, outputCols, outChannels]; | ||
const outShape: [number, number, number, number] = [0, 0, 0, outChannels]; | ||
for (let index = 0; index < 3; index++) { | ||
if (inShape[index] + 2 * zeroPad >= filterShape[index]) { | ||
outShape[index] = round( | ||
(inShape[index] - filterShape[index] + 2 * zeroPad) / strides[index] + | ||
1, | ||
roundingMode); | ||
} | ||
} | ||
return outShape; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adds two changes in
|
||
} | ||
|
||
export function computeDefaultPad( | ||
|
@@ -496,6 +495,10 @@ function get3DPadAndOutInfo( | |
let outHeight: number; | ||
let outWidth: number; | ||
|
||
if (pad === 'valid') { | ||
pad = 0; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If pad is 'valid', it should have the same result as By the way, the roundingMode for this case is supposed to be 'truncate', instead of 'ceil', referring to tensorflow. |
||
|
||
if (typeof pad === 'number') { | ||
const padType = (pad === 0) ? 'VALID' : 'NUMBER'; | ||
padInfo = { | ||
|
@@ -508,8 +511,9 @@ function get3DPadAndOutInfo( | |
type: padType | ||
}; | ||
const outShape = computeOutputShape4D( | ||
[inDepth, inHeight, inWidth, 1], filterDepth, 1, strideDepth, pad, | ||
roundingMode); | ||
[inDepth, inHeight, inWidth, 1], | ||
[filterDepth, filterHeight, filterWidth], 1, | ||
[strideDepth, strideHeight, strideWidth], pad, roundingMode); | ||
outDepth = outShape[0]; | ||
outHeight = outShape[1]; | ||
outWidth = outShape[2]; | ||
|
@@ -529,19 +533,6 @@ function get3DPadAndOutInfo( | |
const right = padAlongWidth - left; | ||
|
||
padInfo = {top, bottom, left, right, front, back, type: 'SAME'}; | ||
} else if (pad === 'valid') { | ||
padInfo = { | ||
top: 0, | ||
bottom: 0, | ||
left: 0, | ||
right: 0, | ||
front: 0, | ||
back: 0, | ||
type: 'VALID' | ||
}; | ||
outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth); | ||
outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight); | ||
outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth); | ||
} else { | ||
throw Error(`Unknown padding parameter: ${pad}`); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @qjia7 @xhcao , WebGPU may have the same issue,
tfjs/tfjs-backend-webgpu/src/pool2d_webgpu.ts
Line 56 in 4314334
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Linchenn . Just curious, is L124
returnValue =
avgValue / count;
missed for change?And another question is that it seems that it's not possible to happen that
count
is zero intfjs/tfjs-backend-webgpu/src/pool2d_webgpu.ts
Line 56 in 4314334
updateSnippet
will increasecount
. Will there be a situation that the filter window is totally no overlap with the input window?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great catch, thanks!
If padding >= filter size, this would happen, as #7122.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for your explanation.
Can you help to cover the webgpu change in this PR since you already find the right place? :) And it will be great if you add a similar case as #7122 to file
avg_pool_3d_test.ts
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Thank you!