Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Revive packed division - add out of bounds check. #1660

Merged
merged 13 commits into from
Apr 3, 2019
16 changes: 9 additions & 7 deletions src/kernels/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1371,11 +1371,11 @@ export class MathBackendWebGL implements KernelBackend {
realDivide(a: Tensor, b: Tensor): Tensor {
const op = binaryop_gpu.DIV;
const outputDtype = 'float32';
// TODO: /~https://github.com/tensorflow/tfjs/issues/1324
// Revive this once we understand why this produces NaNs.
// if (ENV.get('WEBGL_PACK_BINARY_OPERATIONS')) {
// return this.packedBinaryOp(a, b, binaryop_packed_gpu.DIV, outputDtype);
// }
if (ENV.get('WEBGL_PACK_BINARY_OPERATIONS')) {
const checkOutOfBounds = true;
return this.packedBinaryOp(
a, b, binaryop_packed_gpu.DIV, outputDtype, checkOutOfBounds);
}
const program = new BinaryOpProgram(op, a.shape, b.shape);
const output = this.makeOutputArray(program.outputShape, outputDtype);
return this.compileAndRun<Tensor>(program, [a, b], output);
Expand Down Expand Up @@ -1412,8 +1412,10 @@ export class MathBackendWebGL implements KernelBackend {
}

/**
 * Runs a binary op on two inputs through the packed shader program.
 *
 * Builds a BinaryOpPackedProgram from the op snippet and the shapes of a and
 * b, creates the output with makePackedTensor using the program's output
 * shape and the requested dtype, then executes via compileAndRun.
 *
 * The longer signature adds checkOutOfBounds (defaults to false), which is
 * forwarded unchanged to the BinaryOpPackedProgram constructor.
 *
 * NOTE(review): this span carries two parameter lists and two declarations
 * of program back to back. It reads like a rendered diff in which the
 * four-parameter form is the previous version and the checkOutOfBounds form
 * is its replacement - confirm which lines are live before editing.
 */
private packedBinaryOp(
a: TensorHandle, b: TensorHandle, op: string, dtype: DataType) {
const program = new BinaryOpPackedProgram(op, a.shape, b.shape);
a: TensorHandle, b: TensorHandle, op: string, dtype: DataType,
checkOutOfBounds = false) {
const program =
new BinaryOpPackedProgram(op, a.shape, b.shape, checkOutOfBounds);
// The packed output is cast to Tensor so compileAndRun<Tensor> can return it.
const output = this.makePackedTensor(program.outputShape, dtype) as Tensor;
return this.compileAndRun<Tensor>(program, [a, b], output);
}
Expand Down
10 changes: 8 additions & 2 deletions src/kernels/webgl/binaryop_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@ export const MUL = 'return a * b;';

// Without the equality check div produces 0.9999 for a = b, which when
// floored can cause errors.
//
// The longer snippet below additionally tests the divisor first and returns
// NAN when b is 0.0, so the a == b shortcut is only reached for a non-zero b.
//
// NOTE(review): DIV is declared twice in a row here. This looks like a
// rendered diff (the one-line form first, then its replacement) - confirm
// which definition is live. Also note the closing brace of the a == b branch
// in the longer snippet is followed by a semicolon, which leaves an empty
// statement in the snippet text; presumably harmless, but verify.
export const DIV = `if (a == b) return 1.0;
return a / b;`;
export const DIV = `
if (b == 0.0) {
return NAN;
}
if (a == b) {
return 1.0;
};
return a / b;`;

// We use native integer division to deal with floating point imprecision. Since
// we implement floor division and glsl implements truncated division, we
Expand Down
75 changes: 69 additions & 6 deletions src/kernels/webgl/binaryop_packed_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
*/

import * as broadcast_util from '../../ops/broadcast_util';
import {sizeFromShape} from '../../util';
import {getChannels} from '../packing_util';

import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';

const CHECK_NAN_SNIPPET = `
result.r = isNaN.r > 0. ? NAN : result.r;
Expand All @@ -31,10 +35,27 @@ export const DIV = `
// vec4 one = vec4(equal(a, b));
// return one + (vec4(1.0) - one) * a / b;
vec4 result = a / b;
result.x = a.x == b.x ? 1. : result.x;
result.y = a.y == b.y ? 1. : result.y;
result.z = a.z == b.z ? 1. : result.z;
result.w = a.w == b.w ? 1. : result.w;
if(b.x == 0.0) {
result.x = NAN;
} else if(a.x == b.x) {
result.x = 1.;
}
if(b.y == 0.0) {
result.y = NAN;
} else if(a.y == b.y) {
result.y = 1.;
}
if(b.z == 0.0) {
result.z = NAN;
} else if(a.z == b.z) {
result.z = 1.;
}
if(b.w == 0.0) {
result.w = NAN;
} else if(a.w == b.w) {
result.w = 1.;
}

return result;
`;

Expand Down Expand Up @@ -159,9 +180,47 @@ export class BinaryOpPackedProgram implements GPGPUProgram {
supportsBroadcasting = true;
usesPackedTextures = true;

constructor(op: string, aShape: number[], bShape: number[]) {
constructor(
op: string, aShape: number[], bShape: number[],
checkOutOfBounds = false) {
this.outputShape =
broadcast_util.assertAndGetBroadcastShape(aShape, bShape);
const rank = this.outputShape.length;
let checkOutOfBoundsString = '';
if (checkOutOfBounds) {
if (rank === 0 || sizeFromShape(this.outputShape) === 1) {
checkOutOfBoundsString = `
result.y = 0.;
result.z = 0.;
result.w = 0.;
`;
} else {
const dtype = getCoordsDataType(rank);
checkOutOfBoundsString = `
${dtype} coords = getOutputCoords();
`;
if (rank === 1) {
checkOutOfBoundsString += `
result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. : result.y;
result.z = 0.;
result.w = 0.;
`;
} else {
const channels = getChannels('coords', rank);
checkOutOfBoundsString += `
bool nextRowOutOfBounds =
(${channels[rank - 2]} + 1) >= ${this.outputShape[rank - 2]};
bool nextColOutOfBounds =
(${channels[rank - 1]} + 1) >= ${this.outputShape[rank - 1]};

result.y = nextColOutOfBounds ? 0. : result.y;
result.z = nextRowOutOfBounds ? 0. : result.z;
result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
`;
}
}
}

this.userCode = `
vec4 binaryOperation(vec4 a, vec4 b) {
${op}
Expand All @@ -170,7 +229,11 @@ export class BinaryOpPackedProgram implements GPGPUProgram {
void main() {
vec4 a = getAAtOutCoords();
vec4 b = getBAtOutCoords();
setOutput(binaryOperation(a, b));
vec4 result = binaryOperation(a, b);

${checkOutOfBoundsString}

setOutput(result);
}
`;
}
Expand Down
24 changes: 20 additions & 4 deletions src/ops/binary_ops_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,27 @@ describeWithFlags('div', PACKED_ENVS, () => {
it('works when unused channels are divided', () => {
// Tests that the 0's in unused channels for input textures do not corrupt
// the result when swizzled with 3 / 3.
//
// NOTE(review): this body declares a, b and c twice. It reads like a rendered
// diff: the [3]-valued tensors with a.div(b).matMul(b) expecting [3] look like
// the previous version, and the [1]-valued tensors with tf.add(a, b).div(a)
// look like the replacement - confirm which lines are live.
const a = tf.tensor2d([3], [1, 1]);
const b = tf.tensor2d([3], [1, 1]);
const a = tf.tensor2d([1], [1, 1]);
const b = tf.tensor2d([1], [1, 1]);

// Single-element form: 3 / 3 = 1, then a 1x1 matMul with b gives 3.
const c = a.div(b).matMul(b);
expectArraysClose(c, [3]);
// Two identical division results: (1 + 1) / 1 = 2 each.
const c = tf.add(a, b).div(a);
const d = tf.add(a, b).div(a);

// 1x1 matMul of the two quotients: 2 * 2 = 4.
const result = c.matMul(d);
expectArraysClose(result, [4]);
});

it('works when unused channels in tensors with size > 1 are divided', () => {
  // A [3, 1] column divided element-wise by an identical column: all ones.
  const columnNumerator = tf.tensor2d([1, 2, 3], [3, 1]);
  const columnDenominator = tf.tensor2d([1, 2, 3], [3, 1]);
  const columnQuotient = columnNumerator.div(columnDenominator);

  // A rank-1 tensor divided by an identical one, then viewed as a [1, 3] row.
  const flatNumerator = tf.tensor1d([1, 2, 3]);
  const flatDenominator = tf.tensor1d([1, 2, 3]);
  const rowQuotient = flatNumerator.div(flatDenominator).reshape([1, 3]);

  // [3, 1] x [1, 3] of ones yields nine ones; any other value would mean the
  // division wrote something unexpected into the product's inputs.
  const product = columnQuotient.matMul(rowQuotient);
  expectArraysClose(product, [1, 1, 1, 1, 1, 1, 1, 1, 1]);
});
});

Expand Down
2 changes: 1 addition & 1 deletion src/test_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ export function expectNumbersClose(a: number, e: number, epsilon?: number) {
}

function areClose(a: number, e: number, epsilon: number): boolean {
if (isNaN(a) && isNaN(e)) {
if (!isFinite(a) && !isFinite(e)) {
return true;
}
if (isNaN(a) || isNaN(e) || Math.abs(a - e) > epsilon) {
Expand Down