diff --git a/src/kernels/webgl/webgl_custom_op_test.ts b/src/kernels/webgl/webgl_custom_op_test.ts index db75897a93..08b6945733 100644 --- a/src/kernels/webgl/webgl_custom_op_test.ts +++ b/src/kernels/webgl/webgl_custom_op_test.ts @@ -55,20 +55,18 @@ describeWithFlags('custom-op webgl', WEBGL_ENVS, () => { } function squareAndAdd(x: T): T { - const webglBackend = tf.ENV.backend as tf.webgl.MathBackendWebGL; - const program = new SquareAndAddKernel(x.shape); - const backpropProgram = new SquareAndAddBackpropKernel(x.shape); + const fn = tf.customGrad(x => { + const webglBackend = tf.ENV.backend as tf.webgl.MathBackendWebGL; + const program = new SquareAndAddKernel(x.shape); + const backpropProgram = new SquareAndAddBackpropKernel(x.shape); - const forward = () => webglBackend.compileAndRun(program, [x]); + const value = webglBackend.compileAndRun(program, [x]); - const backward = (dy: T) => { - return { - x: () => webglBackend.compileAndRun(backpropProgram, [x]).mul(dy) as T - }; - }; - - const res = tf.ENV.engine.runKernel(forward, {x}, backward); - return res as T; + const gradFunc = (dy: T) => + webglBackend.compileAndRun(backpropProgram, [x]).mul(dy) as T; + return {value, gradFunc}; + }); + return fn(x) as T; } it('lets users use custom operations', () => {