Skip to content

Commit

Permalink
webgpu: Implement isnan (#6075)
Browse files Browse the repository at this point in the history
Wgsl removes inNaN from spec (gpuweb/gpuweb#2311)
This CL implement isnan based on the rules in IEEE 754-1985
  • Loading branch information
shaoboyan authored Feb 18, 2022
1 parent a51929e commit aa2c7e1
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 23 deletions.
2 changes: 1 addition & 1 deletion tfjs-backend-webgpu/src/argminmax_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ export class ArgMinMaxProgram implements WebGPUProgram {
for (var k = i32(localId.x); k < Length && outputIndex < uniforms.size;
k = k + i32(workGroupSizeX)) {
let candidate = f32(x.numbers[getInputIndex(coordInfo, k)]);
if (!isNanCustom(candidate) && candidate ${this.op} bestValue) {
if (!isnan(candidate) && candidate ${this.op} bestValue) {
bestValue = candidate;
bestIndex = k;
}
Expand Down
6 changes: 3 additions & 3 deletions tfjs-backend-webgpu/src/binary_op_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ const LOGICAL_AND = 'return f32(f32(a) >= 1.0 && f32(b) >= 1.0);';
const LOGICAL_AND_VEC4 = `return (vec4<f32>(a >= vec4<f32>(1.0)) *
vec4<f32>(b >= vec4<f32>(1.0)));`;
const CHECK_NAN_SNIPPET = `
if (isNanCustom(a)) { return a; }
if (isNanCustom(b)) { return b; }
if (isnan(a)) { return a; }
if (isnan(b)) { return b; }
`;
const CHECK_NAN_SNIPPET_VEC4 = `
if (isNaN.r) {
Expand Down Expand Up @@ -158,7 +158,7 @@ function getMinMaxString(op: string, useVec4: boolean) {
const checkNanSnippet = useVec4 ? CHECK_NAN_SNIPPET_VEC4 : CHECK_NAN_SNIPPET;
return useVec4 ? `
var resultTemp = vec4<f32>(${op}(a, b));
let isNaN = isNanCustomVec4(a) | isNanCustomVec4(b);
let isNaN = isnanVec4(a) | isnanVec4(b);
` + checkNanSnippet +
`
return resultTemp;
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-webgpu/src/clip_vec4_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export class ClipVec4Program implements WebGPUProgram {
let value = getAByOutputIndex(index);
var clampedValue : vec4<f32>;
for (var i = 0; i < 4; i = i + 1) {
if (isNanCustom(value[i])) {
if (isnan(value[i])) {
clampedValue[i] = value[i];
} else {
clampedValue[i] = clamp(value[i], uniforms.minVal, uniforms.maxVal);
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-webgpu/src/clip_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export class ClipProgram implements WebGPUProgram {
${getMainHeaderAndGlobalIndexString()}
if(index < uniforms.size) {
let value = getAByOutputIndex(index);
if (isNanCustom(value)) {
if (isnan(value)) {
setOutputAtIndex(index, value);
return;
}
Expand Down
4 changes: 2 additions & 2 deletions tfjs-backend-webgpu/src/reduce_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ export class ReduceProgram implements WebGPUProgram {
let initValue = '0.0';
if (this.reduceType === 'min' || this.reduceType === 'max') {
reduceOp = `
if (isNanCustom(candidate)) {
if (isnan(candidate)) {
bestValue = uniforms.NAN;
} else if (!isNanCustom(bestValue) && candidate ${
} else if (!isnan(bestValue) && candidate ${
this.reduceType === 'min' ? '<' : '>'} bestValue)
{ bestValue = candidate; }`;
initValue = 'f32(x.numbers[offset])';
Expand Down
25 changes: 11 additions & 14 deletions tfjs-backend-webgpu/src/shader_preprocessor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -269,20 +269,17 @@ const commonSnippet = `
return res;
}
fn isNanCustom(val : f32) -> bool {
if (val > 0.0) {
return false;
}
if (val < 0.0) {
return false;
}
if (val == 0.0) {
return false;
}
return true;
}
fn isNanCustomVec4(val : vec4<f32>) -> vec4<bool> {
return vec4<bool>(isNanCustom(val[0]), isNanCustom(val[1]), isNanCustom(val[2]), isNanCustom(val[3]));
// NaN defination in IEEE 754-1985 is :
// - sign = either 0 or 1.
// - biased exponent = all 1 bits.
// - fraction = anything except all 0 bits (since all 0 bits represents infinity).
// https://en.wikipedia.org/wiki/IEEE_754-1985#Representation_of_non-numbers
fn isnan(val: f32) -> bool {
let floatToUint: u32 = bitcast<u32>(val);
return (floatToUint & 0x7fffffffu) > 0x7f800000u;
}
fn isnanVec4(val : vec4<f32>) -> vec4<bool> {
return vec4<bool>(isnan(val[0]), isnan(val[1]), isnan(val[2]), isnan(val[3]));
}
`;

Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-webgpu/src/unary_op_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ const RELU6_VEC4 =
'return clamp(a, vec4<f32>(0.0, 0.0, 0.0, 0.0), vec4<f32>(6.0, 6.0, 6.0, 6.0));';
const RELU_VEC4 = `
var resFloat = a * vec4<f32>(a >= vec4<f32>(0.0));
let isNaN = isNanCustomVec4(a);
let isNaN = isnanVec4(a);
if (isNaN.r) {
resFloat.r = a.r;
Expand Down

0 comments on commit aa2c7e1

Please sign in to comment.