Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Commit

Permalink
Align pow gradient with TF (#1446)
Browse files Browse the repository at this point in the history
BUG
#### Description
<!--
Please describe the pull request here.
Also, if this is an issue/bug fix, please add the issue link for reference here.
-->
Multiple problems with the gradient for pow were brought up in tensorflow/tfjs#346.  One was fixed by #1376, the other is that the gradient with respect to the exponent is nan when the base is negative.  This differs from TF which returns 0 in this case: /~https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/math_grad.py#L1046

---
<!-- Please do not delete this section -->
##### For repository owners only:

Please remember to apply all applicable tags to your pull request.
Tags: FEATURE, BREAKING, BUG, PERF, DEV, DOC, SECURITY

For more info see: /~https://github.com/tensorflow/tfjs/blob/master/DEVELOPMENT.md

<!-- Reviewable:start -->
---
This change is [<img src="https://reviewable.io/review_button.svg" height="34" align="absmiddle" alt="Reviewable"/>](https://reviewable.io/reviews/tensorflow/tfjs-core/1446)
<!-- Reviewable:end -->
  • Loading branch information
jgartman authored and Nikhil Thorat committed Jan 4, 2019
1 parent d23adfe commit 1b475ac
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 19 deletions.
29 changes: 14 additions & 15 deletions src/ops/arithmetic_packed_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import {ALL_ENVS, expectArraysClose, expectArraysEqual} from '../test_util';

// TODO(/~https://github.com/tensorflow/tfjs/issues/1050):
// Remove this file as it contains a full copy of src/ops/arithmetic_test.js
// content with WEBGL_PACK_BINARY_OPERATIONS set for all tests. Once
// content with WEBGL_PACK_BINARY_OPERATIONS set for all tests. Once
// /~https://github.com/tensorflow/tfjs/issues/1050 is done, there is no further
// need for this file.
describeWithFlags('div', ALL_ENVS, () => {
Expand All @@ -33,8 +33,8 @@ describeWithFlags('div', ALL_ENVS, () => {
});

afterAll(() => {
tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS',
webglPackBinaryOperationsSavedFlag);
tf.ENV.set(
'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsSavedFlag);
});

it('same shape', () => {
Expand Down Expand Up @@ -335,8 +335,8 @@ describeWithFlags('mul', ALL_ENVS, () => {
});

afterAll(() => {
tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS',
webglPackBinaryOperationsSavedFlag);
tf.ENV.set(
'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsSavedFlag);
});

it('strict same-shaped tensors', () => {
Expand Down Expand Up @@ -641,8 +641,8 @@ describeWithFlags('pow', ALL_ENVS, () => {
});

afterAll(() => {
tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS',
webglPackBinaryOperationsSavedFlag);
tf.ENV.set(
'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsSavedFlag);
});

it('same-shaped tensors', () => {
Expand Down Expand Up @@ -827,8 +827,7 @@ describeWithFlags('pow', ALL_ENVS, () => {
expect(db.shape).toEqual(b.shape);
expect(db.dtype).toEqual('float32');
expectArraysClose(db, [
NaN, 5 * Math.pow(.5, 2) * Math.log(.5),
10 * Math.pow(2, -1) * Math.log(2)
0, 5 * Math.pow(.5, 2) * Math.log(.5), 10 * Math.pow(2, -1) * Math.log(2)
]);
});

Expand Down Expand Up @@ -947,8 +946,8 @@ describeWithFlags('add', ALL_ENVS, () => {
});

afterAll(() => {
tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS',
webglPackBinaryOperationsSavedFlag);
tf.ENV.set(
'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsSavedFlag);
});

it('c + A', () => {
Expand Down Expand Up @@ -1248,8 +1247,8 @@ describeWithFlags('addN', ALL_ENVS, () => {
});

afterAll(() => {
tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS',
webglPackBinaryOperationsSavedFlag);
tf.ENV.set(
'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsSavedFlag);
});

it('a single tensor', () => {
Expand Down Expand Up @@ -1327,8 +1326,8 @@ describeWithFlags('sub', ALL_ENVS, () => {
});

afterAll(() => {
tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS',
webglPackBinaryOperationsSavedFlag);
tf.ENV.set(
'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsSavedFlag);
});

it('c - A', () => {
Expand Down
16 changes: 14 additions & 2 deletions src/ops/arithmetic_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -813,11 +813,23 @@ describeWithFlags('pow', ALL_ENVS, () => {
expect(db.shape).toEqual(b.shape);
expect(db.dtype).toEqual('float32');
expectArraysClose(db, [
NaN, 5 * Math.pow(.5, 2) * Math.log(.5),
10 * Math.pow(2, -1) * Math.log(2)
0, 5 * Math.pow(.5, 2) * Math.log(.5), 10 * Math.pow(2, -1) * Math.log(2)
]);
});

it('gradient wrt exponent with negative base', () => {
const a = tf.tensor1d([-1, -.5, -2.7]);
const b = tf.tensor1d([3, 2, -1], 'int32');
const dy = tf.tensor1d([1, 1, 1]);

const grads = tf.grads((a, b) => tf.pow(a, b));
const [, db] = grads([a, b], dy);

expect(db.shape).toEqual(b.shape);
expect(db.dtype).toEqual('float32');
expectArraysClose(db, [0, 0, 0]);
});

it('gradient: scalar / Tensor1D', () => {
const a = tf.scalar(2);
const b = tf.tensor1d([3, 4, 5]);
Expand Down
6 changes: 4 additions & 2 deletions src/ops/binary_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import {TensorLike, upcastType} from '../types';
import * as util from '../util';
import * as broadcast_util from './broadcast_util';
import {op} from './operation';
import {scalar} from './tensor_ops';
import {scalar, zerosLike} from './tensor_ops';
import {neg} from './unary_ops';

/**
Expand Down Expand Up @@ -265,7 +265,9 @@ function pow_<T extends Tensor>(base: T|TensorLike, exp: Tensor|TensorLike): T {
return res.reshape($base.shape) as T;
};
const derExp = () => {
let res = dy.mul(y.mul($base.log()).toFloat());
const condition = $base.greater(0);
const logBase = $base.log().where(condition, zerosLike($base));
let res = dy.mul(y.mul(logBase));
const reduceAxes = broadcast_util.getReductionAxes($exp.shape, outShape);
if (reduceAxes.length > 0) {
res = res.sum(reduceAxes);
Expand Down

0 comments on commit 1b475ac

Please sign in to comment.