From ae174dcf02123e32f40a6f8f01d2152137b57508 Mon Sep 17 00:00:00 2001
From: Cosmin Ciocan
Date: Tue, 14 Feb 2023 15:54:19 +0200
Subject: [PATCH 1/2] #7376 Add option to use seed with Orthogonal Initializer
 and use QR decomposition method instead of gramSchmidt for consistency with
 the python implementation

---
 tfjs-layers/src/initializers.ts      | 54 +++++++++++++++++++---------
 tfjs-layers/src/initializers_test.ts | 17 +++++++++
 tfjs-layers/src/model_save_test.ts   | 12 +++----
 3 files changed, 60 insertions(+), 23 deletions(-)

diff --git a/tfjs-layers/src/initializers.ts b/tfjs-layers/src/initializers.ts
index c2e7e01e62e..d7f1ee3d8c5 100644
--- a/tfjs-layers/src/initializers.ts
+++ b/tfjs-layers/src/initializers.ts
@@ -8,7 +8,7 @@
  * =============================================================================
  */
 
-import {DataType, eye, linalg, mul, ones, randomUniform, scalar, serialization, Tensor, Tensor2D, tidy, transpose, truncatedNormal, zeros} from '@tensorflow/tfjs-core';
+import {DataType, eye, linalg, mul, ones, randomUniform, scalar, serialization, Tensor, tidy, truncatedNormal, zeros} from '@tensorflow/tfjs-core';
 
 import * as K from './backend/tfjs_backend';
 import {checkDataFormat} from './common';
@@ -529,6 +529,7 @@ export class Orthogonal extends Initializer {
   /** @nocollapse */
   static className = 'Orthogonal';
   readonly DEFAULT_GAIN = 1;
+  readonly ELEMENTS_WARN_SLOW = 2000;
   protected readonly gain: number;
   protected readonly seed: number;
 
@@ -536,11 +537,6 @@ export class Orthogonal extends Initializer {
     super();
     this.gain = args.gain == null ? this.DEFAULT_GAIN : args.gain;
     this.seed = args.seed;
-
-    if (this.seed != null) {
-      throw new NotImplementedError(
-          'Random seed is not implemented for Orthogonal Initializer yet.');
-    }
   }
 
   apply(shape: Shape, dtype?: DataType): Tensor {
@@ -548,22 +544,46 @@ export class Orthogonal extends Initializer {
       if (shape.length < 2) {
         throw new NotImplementedError('Shape must be at least 2D.');
       }
-      if (shape[0] * shape[1] > 2000) {
+      if (dtype !== 'int32' && dtype !== 'float32' && dtype !== undefined) {
+        throw new TypeError(`Unsupported data type ${dtype}.`);
+      }
+      dtype = dtype as 'int32' | 'float32' | undefined;
+
+      // flatten the input shape with the last dimension remaining its
+      // original shape so it works for conv2d
+      let numRows = 1;
+      for (const dim of shape.slice(0, -1)) {
+        numRows *= dim;
+      }
+      const numCols = shape[shape.length - 1];
+      const numElements = numRows * numCols;
+      if (numElements > this.ELEMENTS_WARN_SLOW) {
         console.warn(
             `Orthogonal initializer is being called on a matrix with more ` +
-            `than 2000 (${shape[0] * shape[1]}) elements: ` +
+            `than ${this.ELEMENTS_WARN_SLOW} (${numElements}) elements: ` +
             `Slowness may result.`);
       }
-
-      // TODO(cais): Add seed support.
-      const normalizedShape =
-          shape[0] > shape[1] ? [shape[1], shape[0]] : shape;
-      const a = K.randomNormal(normalizedShape, 0, 1, 'float32') as Tensor2D;
-      let q = linalg.gramSchmidt(a) as Tensor2D;
-      if (shape[0] > shape[1]) {
-        q = transpose(q);
+      const maxDim = Math.max(numCols, numRows);
+      const minDim = Math.min(numCols, numRows);
+      const flatShape = [maxDim, minDim];
+
+      // Generate a random matrix
+      const randNormalMat = K.randomNormal(flatShape, 0, 1, dtype, this.seed);
+
+      // Compute QR factorization
+      const qr = linalg.qr(randNormalMat, false);
+      let qMat = qr[0];
+      const rMat = qr[1];
+
+      // Make Q uniform
+      const diag =
+          rMat.flatten().stridedSlice([0], [minDim * minDim], [minDim + 1]);
+      qMat = mul(qMat, diag.sign());
+      if (numRows < numCols) {
+        qMat = qMat.transpose();
       }
-      return mul(this.gain, q);
+
+      return mul(scalar(this.gain), qMat.reshape(shape));
     });
   }
 
diff --git a/tfjs-layers/src/initializers_test.ts b/tfjs-layers/src/initializers_test.ts
index 14401c8ac1d..04e57a78dbe 100644
--- a/tfjs-layers/src/initializers_test.ts
+++ b/tfjs-layers/src/initializers_test.ts
@@ -831,4 +831,21 @@ describeMathCPUAndWebGL2('Orthogonal Initializer', () => {
     expect((model.predict(randomNormal([1, 128, 128, 1])) as Tensor).shape)
         .toEqual([1, 128, 128, 1]);
   });
+
+  it('with configured seed', () => {
+    const initializerConfig: serialization.ConfigDict = {
+      className: 'Orthogonal',
+      config: {seed: 666013}
+    };
+
+    const expectedInitializer = getInitializer(initializerConfig);
+    const actualInitializer = getInitializer(initializerConfig);
+
+    const expected = expectedInitializer.apply([7, 2], 'float32');
+    const actual = actualInitializer.apply([7, 2], 'float32');
+
+    expect(actual.shape).toEqual(expected.shape);
+    expect(actual.dtype).toEqual(expected.dtype);
+    expectTensorsClose(actual, expected);
+  });
 });
diff --git a/tfjs-layers/src/model_save_test.ts b/tfjs-layers/src/model_save_test.ts
index 1c8b2657a62..dba6af3fa6e 100644
--- a/tfjs-layers/src/model_save_test.ts
+++ b/tfjs-layers/src/model_save_test.ts
@@ -386,13 +386,13 @@ describeMathCPUAndWebGL2('Save-load round trips', () => {
         }));
     const modelJSON = model.toJSON(null, false);
 
-    const gramSchmidtSpy = spyOn(linalg, 'gramSchmidt').and.callThrough();
+    const qrSpy = spyOn(linalg, 'qr').and.callThrough();
     const modelPrime =
         await tfl.models.modelFromJSON({modelTopology: modelJSON});
     // Make sure modelPrime builds.
     modelPrime.predict(zeros([2, 3, 4]));
     // Assert the orthogonal initializer has been called.
-    expect(gramSchmidtSpy).toHaveBeenCalled();
+    expect(qrSpy).toHaveBeenCalled();
   });
 
   it('Partial non-strict load calls weight initializers', async () => {
@@ -415,7 +415,7 @@ describeMathCPUAndWebGL2('Save-load round trips', () => {
     expect(savedArtifacts.weightSpecs.length).toEqual(3);
     savedArtifacts.weightSpecs = savedArtifacts.weightSpecs.slice(0, 1);
 
-    const gramSchmidtSpy = spyOn(linalg, 'gramSchmidt').and.callThrough();
+    const qrSpy = spyOn(linalg, 'qr').and.callThrough();
     const strict = false;
     const modelPrime =
         await tfl.loadLayersModel(io.fromMemory(savedArtifacts), {strict});
@@ -423,7 +423,7 @@ describeMathCPUAndWebGL2('Save-load round trips', () => {
     expect(weightsPrime.length).toEqual(weights.length);
     expectTensorsClose(weightsPrime[0], weights[0]);
     // Assert the orthogonal initializer has been called.
-    expect(gramSchmidtSpy).toHaveBeenCalled();
+    expect(qrSpy).toHaveBeenCalled();
   });
 
   it('loadLayersModel: non-strict load calls weight initializers', async () => {
@@ -446,7 +446,7 @@ describeMathCPUAndWebGL2('Save-load round trips', () => {
     expect(savedArtifacts.weightSpecs.length).toEqual(3);
     savedArtifacts.weightSpecs = savedArtifacts.weightSpecs.slice(0, 1);
 
-    const gramSchmidtSpy = spyOn(linalg, 'gramSchmidt').and.callThrough();
+    const qrSpy = spyOn(linalg, 'qr').and.callThrough();
     const strict = false;
     const modelPrime =
        await tfl.loadLayersModel(io.fromMemory(savedArtifacts), {strict});
@@ -454,7 +454,7 @@ describeMathCPUAndWebGL2('Save-load round trips', () => {
     expect(weightsPrime.length).toEqual(weights.length);
     expectTensorsClose(weightsPrime[0], weights[0]);
     // Assert the orthogonal initializer has been called.
-    expect(gramSchmidtSpy).toHaveBeenCalled();
+    expect(qrSpy).toHaveBeenCalled();
   });
 
   it('Load model artifact with ndarray-format scalar objects', async () => {

From 90e8a36de0f63c1addde3d23f7371f387f3a7822 Mon Sep 17 00:00:00 2001
From: Cosmin Ciocan
Date: Wed, 15 Feb 2023 09:08:11 +0200
Subject: [PATCH 2/2] #7376 Use util.sizeFromShape method and remove
 intermediate variables in Orthogonal Initializer apply method

---
 tfjs-layers/src/initializers.ts | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/tfjs-layers/src/initializers.ts b/tfjs-layers/src/initializers.ts
index d7f1ee3d8c5..369be829c45 100644
--- a/tfjs-layers/src/initializers.ts
+++ b/tfjs-layers/src/initializers.ts
@@ -8,7 +8,7 @@
  * =============================================================================
  */
 
-import {DataType, eye, linalg, mul, ones, randomUniform, scalar, serialization, Tensor, tidy, truncatedNormal, zeros} from '@tensorflow/tfjs-core';
+import {DataType, eye, linalg, mul, ones, randomUniform, scalar, serialization, Tensor, tidy, truncatedNormal, util, zeros} from '@tensorflow/tfjs-core';
 
 import * as K from './backend/tfjs_backend';
 import {checkDataFormat} from './common';
@@ -551,10 +551,7 @@ export class Orthogonal extends Initializer {
 
       // flatten the input shape with the last dimension remaining its
      // original shape so it works for conv2d
-      let numRows = 1;
-      for (const dim of shape.slice(0, -1)) {
-        numRows *= dim;
-      }
+      const numRows = util.sizeFromShape(shape.slice(0, -1));
       const numCols = shape[shape.length - 1];
       const numElements = numRows * numCols;
       if (numElements > this.ELEMENTS_WARN_SLOW) {
@@ -563,9 +560,8 @@ export class Orthogonal extends Initializer {
             `than ${this.ELEMENTS_WARN_SLOW} (${numElements}) elements: ` +
             `Slowness may result.`);
       }
-      const maxDim = Math.max(numCols, numRows);
-      const minDim = Math.min(numCols, numRows);
-      const flatShape = [maxDim, minDim];
+      const flatShape =
+          [Math.max(numCols, numRows), Math.min(numCols, numRows)];
 
       // Generate a random matrix
       const randNormalMat = K.randomNormal(flatShape, 0, 1, dtype, this.seed);
@@ -576,8 +572,9 @@ export class Orthogonal extends Initializer {
       const rMat = qr[1];
 
       // Make Q uniform
-      const diag =
-          rMat.flatten().stridedSlice([0], [minDim * minDim], [minDim + 1]);
+      const diag = rMat.flatten().stridedSlice(
+          [0], [Math.min(numCols, numRows) * Math.min(numCols, numRows)],
+          [Math.min(numCols, numRows) + 1]);
       qMat = mul(qMat, diag.sign());
       if (numRows < numCols) {
         qMat = qMat.transpose();