Skip to content

Commit

Permalink
Use VAOs for save+restore of vertexAttribPointer state between differ…
Browse files Browse the repository at this point in the history
…ent webgl programs. (#6913)

BUG
* Use VAOs for webgl gpgpu Programs.

This prevents issues where vertex attrib layouts don't match between programs (which is possible even with the same vertex shader!), or where someone else overwrites vertexAttribPointer between createProgram and executeProgram.

* Support WebGL 1 via OES_vertex_array_object ext.

* Fixes.

Workaround typechecker instanceof constraint into lambda.
Don't save-and-restore VAO since we overwrite it anyway.
Bind index buffer to VAOs.

* Update gpgpu_context.ts

* Update gpgpu_context.ts

Co-authored-by: Ping Yu <[email protected]>
  • Loading branch information
kdashg and pyu10055 authored Oct 26, 2022
1 parent 396a56e commit e9a9d4f
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 20 deletions.
110 changes: 94 additions & 16 deletions tfjs-backend-webgl/src/gpgpu_context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ export interface FenceContext {
isFencePassed(): boolean;
}

type WebGLVao = WebGLVertexArrayObject | WebGLVertexArrayObjectOES;

export interface GPGPUContextProgram extends WebGLProgram {
vao: WebGLVao;
}

export class GPGPUContext {
gl: WebGLRenderingContext;
textureFloatExtension: {};
Expand All @@ -42,12 +48,17 @@ export class GPGPUContext {
indexBuffer: WebGLBuffer;
framebuffer: WebGLFramebuffer;
outputTexture: WebGLTexture|null = null;
program: WebGLProgram|null = null;
program: GPGPUContextProgram|null = null;
private disposed = false;
private disjoint: boolean;
private vertexShader: WebGLShader;
textureConfig: TextureConfig;

createVertexArray: () => WebGLVao | null;
bindVertexArray: (vao: WebGLVao | null) => void;
deleteVertexArray: (vao: WebGLVao | null) => void;
getVertexArray: () => WebGLVao | null;

constructor(gl?: WebGLRenderingContext) {
const glVersion = env().getNumber('WEBGL_VERSION');
if (gl != null) {
Expand All @@ -56,6 +67,51 @@ export class GPGPUContext {
} else {
this.gl = getWebGLContext(glVersion);
}
gl = this.gl;

if (env().getNumber('WEBGL_VERSION') === 2) {
const gl2 = gl as WebGL2RenderingContext;
this.createVertexArray = () => {
return webgl_util.callAndCheck(gl2,
() => gl2.createVertexArray());
};
this.bindVertexArray = (vao: WebGLVao|null) => {
return webgl_util.callAndCheck(gl2,
() => gl2.bindVertexArray(vao as WebGLVertexArrayObject));
};
this.deleteVertexArray = (vao: WebGLVao|null) => {
return webgl_util.callAndCheck(gl2,
() => gl2.deleteVertexArray(vao as WebGLVertexArrayObject));
};
this.getVertexArray = () => {
return webgl_util.callAndCheck(gl2,
() => gl2.getParameter(gl2.VERTEX_ARRAY_BINDING));
};
} else if (gl != null) {
const ext = gl.getExtension('OES_vertex_array_object');
if (ext == null) {
throw new Error(
'All WebGL1 implementations are expected to offer' +
' OES_vertex_array_object.');
}
this.createVertexArray = () => {
return webgl_util.callAndCheck(gl,
() => ext.createVertexArrayOES());
};
this.bindVertexArray = (vao: WebGLVao|null) => {
return webgl_util.callAndCheck(gl,
() => ext.bindVertexArrayOES(vao as WebGLVertexArrayObjectOES));
};
this.deleteVertexArray = (vao: WebGLVao|null) => {
return webgl_util.callAndCheck(gl,
() => ext.deleteVertexArrayOES(vao as WebGLVertexArrayObjectOES));
};
this.getVertexArray = () => {
return webgl_util.callAndCheck(gl,
() => gl.getParameter(ext.VERTEX_ARRAY_BINDING_OES));
};
}

// WebGL 2.0 enables texture floats without an extension.
let COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float';
const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
Expand Down Expand Up @@ -273,9 +329,7 @@ export class GPGPUContext {
this.gl, physicalRows, physicalCols));
}

private vertexAttrsAreBound = false;

public createProgram(fragmentShader: WebGLShader): WebGLProgram {
public createProgram(fragmentShader: WebGLShader): GPGPUContextProgram {
this.throwIfDisposed();
const gl = this.gl;
if (this.vertexShader == null) {
Expand All @@ -286,32 +340,52 @@ export class GPGPUContext {
gl, () => gl.attachShader(program, this.vertexShader));
webgl_util.callAndCheck(gl, () => gl.attachShader(program, fragmentShader));
webgl_util.linkProgram(gl, program);
if (this.debug) {
webgl_util.validateProgram(gl, program);
}
if (!this.vertexAttrsAreBound) {
this.setProgram(program);
this.vertexAttrsAreBound = gpgpu_util.bindVertexProgramAttributeStreams(
gl, this.program, this.vertexBuffer);

let program2: GPGPUContextProgram;
{
program2 = Object.assign(program, {
vao: this.createVertexArray(),
});
this.bindVertexArray(program2.vao);
// Bind index buffer, and vertex buffers based on program attrib
// locations.
webgl_util.callAndCheck(
gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, this.indexBuffer));
console.assert(
gpgpu_util.bindVertexProgramAttributeStreams(gl, program2,
this.vertexBuffer),
'gpgpu_util.bindVertexProgramAttributeStreams not fully successful.');

if (this.debug) {
webgl_util.validateProgram(gl, program2);
}
}
return program;
this.setProgram(program2);

return program2;
}

public deleteProgram(program: WebGLProgram) {
public deleteProgram(program: GPGPUContextProgram) {
this.throwIfDisposed();
if (program === this.program) {
this.program = null;
}
if (program != null) {
webgl_util.callAndCheck(this.gl, () => this.gl.deleteProgram(program));
this.deleteVertexArray(program.vao);
}
}

public setProgram(program: WebGLProgram|null) {
public setProgram(program: GPGPUContextProgram|null) {
this.throwIfDisposed();
this.program = program;
if ((this.program != null) && this.debug) {
webgl_util.validateProgram(this.gl, this.program);

if (this.program != null) {
this.bindVertexArray(this.program.vao);

if (this.debug) {
webgl_util.validateProgram(this.gl, this.program);
}
}
webgl_util.callAndCheck(this.gl, () => this.gl.useProgram(program));
}
Expand Down Expand Up @@ -389,6 +463,10 @@ export class GPGPUContext {
this.throwIfNoProgram();
const gl = this.gl;
if (this.debug) {
const boundVao = this.getVertexArray();
console.assert(boundVao === this.program.vao,
'VAO changed between setProgram and executeProgram!');

this.debugValidate();
}
webgl_util.callAndCheck(
Expand Down
4 changes: 2 additions & 2 deletions tfjs-backend-webgl/src/gpgpu_context_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import {expectArraysEqual} from '@tensorflow/tfjs-core/dist/test_util';
import {WEBGL_ENVS} from './backend_webgl_test_registry';
import * as canvas_util from './canvas_util';
import {getGlslDifferences} from './glsl_version';
import {GPGPUContext, linearSearchLastTrue} from './gpgpu_context';
import {GPGPUContext, GPGPUContextProgram, linearSearchLastTrue} from './gpgpu_context';
import * as tex_util from './tex_util';
import {Texture} from './tex_util';
import {createFragmentShader} from './webgl_util';
Expand Down Expand Up @@ -117,7 +117,7 @@ describeWithFlags(
describeWithFlags(
'GPGPUContext setOutputMatrixWriteRegion', DOWNLOAD_FLOAT_ENVS, () => {
let gpgpu: GPGPUContext;
let program: WebGLProgram;
let program: GPGPUContextProgram;
let output: WebGLTexture;

beforeEach(() => {
Expand Down
4 changes: 2 additions & 2 deletions tfjs-backend-webgl/src/gpgpu_math.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import {backend_util, env, Tensor, TypedArray, util} from '@tensorflow/tfjs-core';

import {GPGPUContext} from './gpgpu_context';
import {GPGPUContext, GPGPUContextProgram} from './gpgpu_context';
import * as shader_compiler from './shader_compiler';
import {InputInfo, ShapeInfo, UniformType} from './shader_compiler';
import {PackingScheme, TextureData, TextureUsage} from './tex_util';
Expand Down Expand Up @@ -47,7 +47,7 @@ export interface GPGPUProgram {
}

export interface GPGPUBinary {
webGLProgram: WebGLProgram;
webGLProgram: GPGPUContextProgram;
program: GPGPUProgram;
uniformLocations: {[name: string]: WebGLUniformLocation};
customUniformLocations?: WebGLUniformLocation[];
Expand Down
5 changes: 5 additions & 0 deletions tfjs-backend-webgl/src/webgl_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ export function linkProgram(gl: WebGLRenderingContext, program: WebGLProgram) {
}
}

/// validateProgram is effectively "If we `useProgram(program); drawArrays();`,
/// give feedback in log about perf/correctness warnings or errors that would
/// occur."
/// So make sure we set up all vertex/texture/sampler/uniform data before
/// calling validateProgram!
export function validateProgram(
gl: WebGLRenderingContext, program: WebGLProgram) {
callAndCheck(gl, () => gl.validateProgram(program));
Expand Down

0 comments on commit e9a9d4f

Please sign in to comment.