-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
webgpu: Support any component buffer #7426
Conversation
092cc9f
to
bffb703
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In future PRs, we may need to 1) clean up isVec4 and everything is component 2) change variable to input
@@ -24,6 +24,7 @@ import {computeDispatch, flatDispatchLayout} from './webgpu_util'; | |||
export class BinaryOpProgram implements WebGPUProgram { | |||
dispatch: [number, number, number]; | |||
dispatchLayout: {x: number[]}; | |||
outputComponent = 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't have to init it as 1 here
@@ -32,6 +32,7 @@ export class Conv2DDerInputProgram implements WebGPUProgram { | |||
size = false; | |||
isVec4 = false; | |||
workPerThread = 1; | |||
outputComponent = 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't have to init it as 1 here
@@ -36,6 +36,7 @@ export class MatMulSplitKProgram implements WebGPUProgram { | |||
transposeB: boolean; | |||
atomic = true; | |||
isVec4 = false; | |||
outputComponent = 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't have to init it as 1 here
@@ -67,6 +68,21 @@ export const compileProgram = | |||
return pipeline; | |||
}; | |||
|
|||
export const typeSnippet = (component: number, type = 'f') => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd prefer passing f32, i32, etc. as type. We may extend this to f16?
return 'vec4<i32>'; | ||
default: | ||
throw new Error(`${component}-component is not supported.`); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if and else can be merged to simplify the impl.
@@ -159,7 +159,7 @@ export class Conv2DMMProgram implements WebGPUProgram { | |||
dispatchLayout: {x: number[], y: number[], z: number[]}; | |||
dispatch: [number, number, number]; | |||
variableNames = ['x', 'W']; | |||
variableTypes: string[]; | |||
variableComponents: number[]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In future, can we change this to inputNames and inputComponents?
@@ -176,6 +176,7 @@ export class Conv2DMMProgram implements WebGPUProgram { | |||
tileInner: number; | |||
innerElementSize: number; | |||
isVec4?: boolean; | |||
outputComponent = 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't have to init it as 1 here
@@ -143,7 +144,8 @@ export class Conv2DDerInputMMProgram implements WebGPUProgram { | |||
this.elementsPerThread); | |||
|
|||
if (this.isVec4) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We may need to clean up isVec4 here.
program.isVec4, | ||
program.variableComponents ? | ||
program.variableComponents[i] : | ||
program.outputComponent ? program.outputComponent : 1, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we just pass program.outputComponent here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW, I didn't finish all the detailed reviews so I may have some other comments after you address the comments above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM and we may do more refactorings in future.
@@ -158,6 +174,8 @@ function makeShader( | |||
const prefixSnippets: string[] = []; | |||
const flatWorkgroupSize = program.workgroupSize[0] * | |||
program.workgroupSize[1] * program.workgroupSize[2]; | |||
program.outputComponent = | |||
program.outputComponent ? program.outputComponent : 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not good to set a class member outside. We can leave this for future refactoring.
@@ -67,6 +68,21 @@ export const compileProgram = | |||
return pipeline; | |||
}; | |||
|
|||
export const typeSnippet = (component: number, type = 'f32') => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The parameter orders of typeSnippet and dataTypeToGPUType are different. We may also refactor this part in future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks.
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is![Reviewable](https://camo.githubusercontent.com/1541c4039185914e83657d3683ec25920c672c6c5c7ab4240ee7bff601adec0b/68747470733a2f2f72657669657761626c652e696f2f7265766965775f627574746f6e2e737667)