Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Return 0 for tf.sign(NaN) to align with TF #998

Merged
merged 5 commits into from
Apr 25, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/kernels/webgl/unaryop_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
* =============================================================================
*/

import * as selu_util from '../../ops/selu_util';
import * as erf_util from '../../ops/erf_util';
import * as selu_util from '../../ops/selu_util';

import {GPGPUProgram} from './gpgpu_math';

Expand Down Expand Up @@ -70,7 +70,10 @@ export const CEIL = `return ceil(x);`;

export const FLOOR = `return floor(x);`;

export const SIGN = `return sign(x);`;
export const SIGN = `
if (isNaN(x)) { return 0.0; }
return sign(x);
`;

export const ROUND = `
// OpenGL ES does not support round function.
Expand Down Expand Up @@ -175,15 +178,15 @@ export const ATANH = `return (log(1.0 + x) - log(1.0 - x)) / 2.0;`;

export const ERF = `
// Error function is calculated approximately with elementary function.
// See "Handbook of Mathematical Functions with Formulas,
// See "Handbook of Mathematical Functions with Formulas,
// Graphs, and Mathematical Tables", Abramowitz and Stegun.
float p = ${erf_util.ERF_P};
float a1 = ${erf_util.ERF_A1};
float a2 = ${erf_util.ERF_A2};
float a3 = ${erf_util.ERF_A3};
float a4 = ${erf_util.ERF_A4};
float a5 = ${erf_util.ERF_A5};

float t = 1.0 / (1.0 + p * x);
return 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
`;
Expand Down