Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Commit

Permalink
Fix unsortedSegmentSum so it can call the Tensorflow backend (#1103)
Browse files Browse the repository at this point in the history
BUG
  • Loading branch information
jgartman authored and Nikhil Thorat committed Jun 14, 2018
1 parent 9b06592 commit aed24ef
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 49 deletions.
22 changes: 7 additions & 15 deletions src/kernels/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -337,30 +337,22 @@ export class MathBackendCPU implements KernelBackend {
unsortedSegmentSum<T extends Tensor>(
x: T, segmentIds: Tensor1D, numSegments: number): Tensor {
const res = [];
const [dim] = segmentIds.shape;
const axis = axis_util.getInnerMostAxes(1, x.rank)[0];

// Reshape the segment id's so that they can be broadcast with
// x. The new shape should be [1, 1, ... 1, dim, 1, ..., 1] where
// dim is at index = axis.
const newShape = [];
for (let i = 0; i < x.shape.length; ++i) {
if (i === axis) {
newShape.push(dim);
} else {
newShape.push(1);
}
// x. The new shape should be [segmentIds.shape, 1, ..., 1]
const numIters = x.rank - segmentIds.rank;
for (let i = 0; i < numIters; ++i) {
segmentIds = segmentIds.expandDims(i + 1);
}

const reshapedSegmentIds = ops.reshape(segmentIds, newShape);
for (let i = 0; i < numSegments; ++i) {
const segmentId = ops.scalar(i, 'int32');
const mask = ops.equal(segmentId, reshapedSegmentIds).asType('float32');
const sum = mask.mul(x).sum(axis);
const mask = ops.equal(segmentId, segmentIds).asType('float32');
const sum = mask.mul(x).sum(0);
res.push(sum);
}

return ops.stack(res, axis) as T;
return ops.stack(res);
}

argMin(x: Tensor, axis: number): Tensor {
Expand Down
28 changes: 20 additions & 8 deletions src/kernels/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -564,15 +564,27 @@ export class MathBackendWebGL implements KernelBackend {

unsortedSegmentSum<T extends Tensor>(
x: T, segmentIds: Tensor1D, numSegments: number): Tensor {
const axis = axis_util.getInnerMostAxes(1, x.rank)[0];
const outShape = segment_util.computeOutShape(x.shape, axis, numSegments);
const inSize = util.sizeFromShape([x.shape[axis]]);
const a2D = x.as2D(-1, inSize);
let axis = 0;
const permutation = axis_util.getAxesPermutation([axis], x.rank);
let permutedX = x;
if (permutation != null) {
permutedX = x.transpose(permutation);
axis = axis_util.getInnerMostAxes(1, x.rank)[0];
}

const outShape =
segment_util.computeOutShape(permutedX.shape, axis, numSegments);
const inSize = util.sizeFromShape([permutedX.shape[axis]]);
const a2D = permutedX.as2D(-1, inSize);
const outputDType = types.sumOutType(x.dtype);
return this
.segOpCompute(
a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments)
.reshape(outShape);
let result =
this.segOpCompute(
a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments)
.reshape(outShape);
if (permutation != null) {
result = result.transpose(axis_util.getUndoAxesPermutation(permutation));
}
return result;
}

private segOpCompute(
Expand Down
38 changes: 13 additions & 25 deletions src/ops/segment_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import {ENV} from '../environment';
import {Tensor, Tensor1D} from '../tensor';
import * as util from '../util';
import {ArrayOps} from './array_ops';
import * as axis_util from './axis_util';
import {BinaryOps} from './binary_ops';
import {CompareOps} from './compare';
import {LogicalOps} from './logical_ops';
Expand Down Expand Up @@ -52,45 +51,34 @@ export class SegmentOps {
segmentIds.dtype === 'int32', 'segmentIds must be of dtype `int32`');
util.assert(util.isInt(numSegments), 'numSegments must be of dtype int');

let axis = 0;
const permutation = axis_util.getAxesPermutation([axis], x.rank);
let permutedX = x;
if (permutation != null) {
permutedX = x.transpose(permutation);
axis = axis_util.getInnerMostAxes(1, x.rank)[0];
}
const gradFunc = (dy: T) => {
const derX = () => {
return gatherDropNegatives(dy, segmentIds, axis);
return gatherDropNegatives(dy, segmentIds);
};
return {permutedX: derX};
return {x: derX};
};
let result = ENV.engine.runKernel(
backend =>
backend.unsortedSegmentSum(permutedX, segmentIds, numSegments) as T,
{permutedX}, gradFunc);
if (permutation != null) {
result = result.transpose(axis_util.getUndoAxesPermutation(permutation));
}
return result;
return ENV.engine.runKernel(
backend =>
backend.unsortedSegmentSum(x, segmentIds, numSegments),
{x}, gradFunc) as T;
}
}

function gatherDropNegatives<T extends Tensor>(
x: T, indices: Tensor1D, axis: number) {
function gatherDropNegatives<T extends Tensor>(x: T, indices: Tensor1D) {
// Helper function for unsorted segment ops. Gathers params for
// positive segment ids and gathers 0 for inputs with negative segment id.
// Mirrors _GatherDropNegatives from tensorflow/python/ops/math_grad.py
const zeroClippedIndices =
BinaryOps.maximum(indices, ArrayOps.zerosLike(indices));
const gathered = ArrayOps.gather(x, zeroClippedIndices as Tensor1D, axis);
const gathered = ArrayOps.gather(x, zeroClippedIndices as Tensor1D);
let isPositive =
CompareOps.greaterEqual(indices, ArrayOps.scalar(0, 'int32'));
for (let i = 0; i < gathered.rank - isPositive.rank; ++i) {
isPositive = ArrayOps.expandDims(isPositive, -1);
const numIters = gathered.rank - isPositive.rank;
for (let i = 0; i < numIters; ++i) {
isPositive = ArrayOps.expandDims(isPositive, i + 1);
}
const bools = ArrayOps.onesLike(gathered).equal(ArrayOps.scalar(1));
isPositive = LogicalOps.logicalAnd(isPositive, bools);
isPositive =
LogicalOps.logicalAnd(isPositive, ArrayOps.ones(gathered.shape, 'bool'));
const zeroSlice = ArrayOps.zerosLike(gathered);
return LogicalOps.where(isPositive, gathered, zeroSlice);
}
1 change: 0 additions & 1 deletion src/ops/segment_ops_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ describeWithFlags('unsortedSegmentSum', ALL_ENVS, () => {
const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [3, 2, 2]);
const segmentIds = tf.tensor1d([2, 1, 2], 'int32');
const numSegments = 3;
// const axis = 0;
const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);

expect(res.shape).toEqual([numSegments, 2, 2]);
Expand Down

0 comments on commit aed24ef

Please sign in to comment.