Skip to content

Commit

Permalink
#7376 Use util.sizeFromShape method and remove intermediate variables…
Browse files Browse the repository at this point in the history
… in Orthogonal Initializer apply method
  • Loading branch information
Cosmin Ciocan committed Feb 15, 2023
1 parent ae174dc commit 3cb4348
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions tfjs-layers/src/initializers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
* =============================================================================
*/

import {DataType, eye, linalg, mul, ones, randomUniform, scalar, serialization, Tensor, tidy, truncatedNormal, zeros} from '@tensorflow/tfjs-core';
import {DataType, eye, linalg, mul, ones, randomUniform, scalar, serialization, Tensor, tidy, truncatedNormal, util, zeros} from '@tensorflow/tfjs-core';

import * as K from './backend/tfjs_backend';
import {checkDataFormat} from './common';
Expand Down Expand Up @@ -551,10 +551,7 @@ export class Orthogonal extends Initializer {

// flatten the input shape with the last dimension remaining its
// original shape so it works for conv2d
let numRows = 1;
for (const dim of shape.slice(0, -1)) {
numRows *= dim;
}
const numRows = util.sizeFromShape(shape.slice(0, -1));
const numCols = shape[shape.length - 1];
const numElements = numRows * numCols;
if (numElements > this.ELEMENTS_WARN_SLOW) {
Expand All @@ -563,9 +560,8 @@ export class Orthogonal extends Initializer {
`than ${this.ELEMENTS_WARN_SLOW} (${numElements}) elements: ` +
`Slowness may result.`);
}
const maxDim = Math.max(numCols, numRows);
const minDim = Math.min(numCols, numRows);
const flatShape = [maxDim, minDim];
const flatShape =
[Math.max(numCols, numRows), Math.min(numCols, numRows)];

// Generate a random matrix
const randNormalMat = K.randomNormal(flatShape, 0, 1, dtype, this.seed);
Expand All @@ -576,8 +572,9 @@ export class Orthogonal extends Initializer {
const rMat = qr[1];

// Make Q uniform
const diag =
rMat.flatten().stridedSlice([0], [minDim * minDim], [minDim + 1]);
const diag = rMat.flatten().stridedSlice(
[0], [Math.min(numCols, numRows) * Math.min(numCols, numRows)],
[Math.min(numCols, numRows) + 1]);
qMat = mul(qMat, diag.sign());
if (numRows < numCols) {
qMat = qMat.transpose();
Expand Down

0 comments on commit 3cb4348

Please sign in to comment.