From aed24ef2d17cf990352de3befa768aad1443b03f Mon Sep 17 00:00:00 2001 From: Josh Gartman Date: Thu, 14 Jun 2018 14:46:27 -0400 Subject: [PATCH] Fix unsortedSegmentSum so it can call the Tensorflow backend (#1103) BUG --- src/kernels/backend_cpu.ts | 22 +++++++-------------- src/kernels/backend_webgl.ts | 28 ++++++++++++++++++-------- src/ops/segment_ops.ts | 38 ++++++++++++------------------------ src/ops/segment_ops_test.ts | 1 - 4 files changed, 40 insertions(+), 49 deletions(-) diff --git a/src/kernels/backend_cpu.ts b/src/kernels/backend_cpu.ts index 80afaffdc3..1572f0cca2 100644 --- a/src/kernels/backend_cpu.ts +++ b/src/kernels/backend_cpu.ts @@ -337,30 +337,22 @@ export class MathBackendCPU implements KernelBackend { unsortedSegmentSum( x: T, segmentIds: Tensor1D, numSegments: number): Tensor { const res = []; - const [dim] = segmentIds.shape; - const axis = axis_util.getInnerMostAxes(1, x.rank)[0]; // Reshape the segment id's so that they can be broadcast with - // x. The new shape should be [1, 1, ... 1, dim, 1, ..., 1] where - // dim is at index = axis. - const newShape = []; - for (let i = 0; i < x.shape.length; ++i) { - if (i === axis) { - newShape.push(dim); - } else { - newShape.push(1); - } + // x. The new shape should be [segmentIds.shape, 1, ..., 1] + const numIters = x.rank - segmentIds.rank; + for (let i = 0; i < numIters; ++i) { + segmentIds = segmentIds.expandDims(i + 1); } - const reshapedSegmentIds = ops.reshape(segmentIds, newShape); for (let i = 0; i < numSegments; ++i) { const segmentId = ops.scalar(i, 'int32'); - const mask = ops.equal(segmentId, reshapedSegmentIds).asType('float32'); - const sum = mask.mul(x).sum(axis); + const mask = ops.equal(segmentId, segmentIds).asType('float32'); + const sum = mask.mul(x).sum(0); res.push(sum); } - return ops.stack(res, axis) as T; + return ops.stack(res); } argMin(x: Tensor, axis: number): Tensor { diff --git a/src/kernels/backend_webgl.ts b/src/kernels/backend_webgl.ts index 46a4ae353e..f2e227d8dc 100644 --- a/src/kernels/backend_webgl.ts +++ b/src/kernels/backend_webgl.ts @@ -564,15 +564,27 @@ export class MathBackendWebGL implements KernelBackend { unsortedSegmentSum( x: T, segmentIds: Tensor1D, numSegments: number): Tensor { - const axis = axis_util.getInnerMostAxes(1, x.rank)[0]; - const outShape = segment_util.computeOutShape(x.shape, axis, numSegments); - const inSize = util.sizeFromShape([x.shape[axis]]); - const a2D = x.as2D(-1, inSize); + let axis = 0; + const permutation = axis_util.getAxesPermutation([axis], x.rank); + let permutedX = x; + if (permutation != null) { + permutedX = x.transpose(permutation); + axis = axis_util.getInnerMostAxes(1, x.rank)[0]; + } + + const outShape = + segment_util.computeOutShape(permutedX.shape, axis, numSegments); + const inSize = util.sizeFromShape([permutedX.shape[axis]]); + const a2D = permutedX.as2D(-1, inSize); const outputDType = types.sumOutType(x.dtype); - return this - .segOpCompute( - a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments) - .reshape(outShape); + let result = + this.segOpCompute( + a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments) + .reshape(outShape); + if (permutation != null) { + result = result.transpose(axis_util.getUndoAxesPermutation(permutation)); + } + return result; } private segOpCompute( diff --git a/src/ops/segment_ops.ts b/src/ops/segment_ops.ts index f76071df68..41d3e9c213 100644 --- a/src/ops/segment_ops.ts +++ b/src/ops/segment_ops.ts @@ -20,7 +20,6 @@ import {ENV} from '../environment'; import {Tensor, Tensor1D} from '../tensor'; import * as util from '../util'; import {ArrayOps} from './array_ops'; -import * as axis_util from './axis_util'; import {BinaryOps} from './binary_ops'; import {CompareOps} from './compare'; import {LogicalOps} from './logical_ops'; @@ -52,45 +51,34 @@ export class SegmentOps { segmentIds.dtype === 'int32', 'segmentIds must be of dtype `int32`'); util.assert(util.isInt(numSegments), 'numSegments must be of dtype int'); - let axis = 0; - const permutation = axis_util.getAxesPermutation([axis], x.rank); - let permutedX = x; - if (permutation != null) { - permutedX = x.transpose(permutation); - axis = axis_util.getInnerMostAxes(1, x.rank)[0]; - } const gradFunc = (dy: T) => { const derX = () => { - return gatherDropNegatives(dy, segmentIds, axis); + return gatherDropNegatives(dy, segmentIds); }; - return {permutedX: derX}; + return {x: derX}; }; - let result = ENV.engine.runKernel( - backend => - backend.unsortedSegmentSum(permutedX, segmentIds, numSegments) as T, - {permutedX}, gradFunc); - if (permutation != null) { - result = result.transpose(axis_util.getUndoAxesPermutation(permutation)); - } - return result; + return ENV.engine.runKernel( + backend => + backend.unsortedSegmentSum(x, segmentIds, numSegments), + {x}, gradFunc) as T; } } -function gatherDropNegatives( - x: T, indices: Tensor1D, axis: number) { +function gatherDropNegatives(x: T, indices: Tensor1D) { // Helper function for unsorted segment ops. Gathers params for // positive segment ids and gathers 0 for inputs with negative segment id. // Mirrors _GatherDropNegatives from tensorflow/python/ops/math_grad.py const zeroClippedIndices = BinaryOps.maximum(indices, ArrayOps.zerosLike(indices)); - const gathered = ArrayOps.gather(x, zeroClippedIndices as Tensor1D, axis); + const gathered = ArrayOps.gather(x, zeroClippedIndices as Tensor1D); let isPositive = CompareOps.greaterEqual(indices, ArrayOps.scalar(0, 'int32')); - for (let i = 0; i < gathered.rank - isPositive.rank; ++i) { - isPositive = ArrayOps.expandDims(isPositive, -1); + const numIters = gathered.rank - isPositive.rank; + for (let i = 0; i < numIters; ++i) { + isPositive = ArrayOps.expandDims(isPositive, i + 1); } - const bools = ArrayOps.onesLike(gathered).equal(ArrayOps.scalar(1)); - isPositive = LogicalOps.logicalAnd(isPositive, bools); + isPositive = + LogicalOps.logicalAnd(isPositive, ArrayOps.ones(gathered.shape, 'bool')); const zeroSlice = ArrayOps.zerosLike(gathered); return LogicalOps.where(isPositive, gathered, zeroSlice); } diff --git a/src/ops/segment_ops_test.ts b/src/ops/segment_ops_test.ts index f1dcedfe32..faaa26c0ab 100644 --- a/src/ops/segment_ops_test.ts +++ b/src/ops/segment_ops_test.ts @@ -45,7 +45,6 @@ describeWithFlags('unsortedSegmentSum', ALL_ENVS, () => { const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [3, 2, 2]); const segmentIds = tf.tensor1d([2, 1, 2], 'int32'); const numSegments = 3; - // const axis = 0; const res = tf.unsortedSegmentSum(t, segmentIds, numSegments); expect(res.shape).toEqual([numSegments, 2, 2]);